{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 74,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 0.999999  , -0.00141019,  0.20437793], dtype=float32)"
      ]
     },
     "execution_count": 74,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import gym\n",
    "\n",
    "\n",
    "# Define the environment\n",
    "class MyWrapper(gym.Wrapper):\n",
    "    \"\"\"Pendulum-v1 wrapper with a simplified reset/step interface.\n",
    "\n",
    "    reset() returns only the observation, and step() returns the 4-tuple\n",
    "    (state, reward, done, info). An episode is also reported as done once\n",
    "    max_steps steps have been taken since the last reset.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, max_steps=200):\n",
    "        \"\"\"Create the wrapped Pendulum-v1 environment.\n",
    "\n",
    "        Args:\n",
    "            max_steps: step count at which step() starts reporting done=True.\n",
    "                Defaults to 200, the value that used to be hard-coded.\n",
    "        \"\"\"\n",
    "        env = gym.make('Pendulum-v1', render_mode='rgb_array')\n",
    "        super().__init__(env)\n",
    "        self.env = env\n",
    "        self.max_steps = max_steps\n",
    "        self.step_n = 0\n",
    "\n",
    "    def reset(self, **kwargs):\n",
    "        \"\"\"Reset the environment and the step counter.\n",
    "\n",
    "        Args:\n",
    "            **kwargs: forwarded unchanged to the wrapped environment's\n",
    "                reset(), e.g. seed=... to make an episode reproducible.\n",
    "\n",
    "        Returns:\n",
    "            The initial observation only; the info dict is discarded.\n",
    "        \"\"\"\n",
    "        state, _ = self.env.reset(**kwargs)\n",
    "        self.step_n = 0\n",
    "        return state\n",
    "\n",
    "    def step(self, action):\n",
    "        \"\"\"Advance the environment by one step.\n",
    "\n",
    "        Returns:\n",
    "            (state, reward, done, info). done is truthy when the wrapped\n",
    "            environment reports terminated or truncated, or when the step\n",
    "            counter has reached max_steps.\n",
    "        \"\"\"\n",
    "        state, reward, terminated, truncated, info = self.env.step(action)\n",
    "        self.step_n += 1\n",
    "        # Fold both end-of-episode flags and the local step cap into one flag.\n",
    "        done = terminated or truncated or self.step_n >= self.max_steps\n",
    "        return state, reward, done, info\n",
    "\n",
    "\n",
    "# Single shared environment instance used by the rest of the notebook.\n",
    "env = MyWrapper()\n",
    "\n",
    "# Reset once; as the last expression, the initial observation is displayed.\n",
    "env.reset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAakAAAGiCAYAAABd6zmYAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAedUlEQVR4nO3df3BU5aH/8c9ufiyQsJsfkF0yJIUZuWKGH1V+br1Te0tKtKnVms5Yh6EpZXSkgQHpMDWt4uh0JgzeW6utYu94K87cYnrpFa0UtJmAoZY1QCQ1BEntFJtccBMlzW5Asvmxz/cPh/N1IWoCJPskvl8zZ8ac85zd5zyN++5mT6LLGGMEAICF3MmeAAAAn4RIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCslbRIPfnkk5oxY4YmTJigJUuW6NChQ8maCgDAUkmJ1G9/+1tt3LhRDz30kN58803Nnz9fJSUl6ujoSMZ0AACWciXjD8wuWbJEixYt0i9/+UtJUjweV0FBgdatW6f7779/tKcDALBU6mg/YW9vrxoaGlRZWensc7vdKi4uVigUGvScWCymWCzmfB2Px9XZ2anc3Fy5XK4RnzMA4Ooyxqi7u1v5+flyuz/5h3qjHqkPPvhAAwMD8vv9Cfv9fr9OnDgx6DlVVVV6+OGHR2N6AIBR1NbWpunTp3/i8VGP1OWorKzUxo0bna8jkYgKCwvV1tYmr9ebxJkBAC5HNBpVQUGBJk+e/KnjRj1SU6ZMUUpKitrb2xP2t7e3KxAIDHqOx+ORx+O5ZL/X6yVSADCGfdZHNqN+d196eroWLFig2tpaZ188Hldtba2CweBoTwcAYLGk/Lhv48aNKi8v18KFC7V48WL9/Oc/17lz57Rq1apkTAcAYKmkROrOO+/U+++/r82bNyscDuuLX/yiXnnllUtupgAAfL4l5fekrlQ0GpXP51MkEuEzKQAYg4b6Os7f7gMAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgrWFH6sCBA7r11luVn58vl8ulF198MeG4MUabN2/WtGnTNHHiRBUXF+udd95JGNPZ2akVK1bI6/UqKytLq1ev1tmzZ6/oQgAA48+wI3Xu3DnNnz9fTz755KDHt27dqieeeEJPP/206uvrlZGRoZKSEvX09DhjVqxYoebmZtXU1Gj37t06cOCA7rnnnsu/CgDA+GSugCSza9cu5+t4PG4CgYB59NFHnX1dXV3G4/GY559/3hhjzPHjx40kc/jwYWfM3r17jcvlMqdOnRrS80YiESPJRCKRK5k+ACBJhvo6flU/kzp58qTC4bCKi4udfT6fT0uWLFEoFJIkhUIhZWVlaeHChc6Y4uJiud1u1dfXD/q4sVhM0Wg0YQMAjH9XNVLhcFiS5Pf7E/b7/X7nWDgcVl5eXsLx1NRU5eTkOGMuVlVVJZ/P52wFBQVXc9oAAEuNibv7KisrFYlEnK2trS3ZUwIAjIKrGqlAICBJam9vT9jf3t7uHAsEAuro6Eg43t/fr87OTmfMxTwej7xeb8IGABj/rmqkZs6cqUAgoNraWmdfNBpVfX29gsGgJCkYDKqr
q0sNDQ3OmH379ikej2vJkiVXczoAgDEudbgnnD17Vn/729+cr0+ePKnGxkbl5OSosLBQGzZs0E9/+lPNmjVLM2fO1IMPPqj8/HzdfvvtkqTrrrtON998s+6++249/fTT6uvr09q1a/Wd73xH+fn5V+3CAADjwHBvG9y/f7+RdMlWXl5ujPnoNvQHH3zQ+P1+4/F4zLJly0xLS0vCY5w5c8bcddddJjMz03i9XrNq1SrT3d191W9dBADYaaiv4y5jjEliIy9LNBqVz+dTJBLh8ykAGIOG+jo+Ju7uAwB8PhEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1hv0HZgGMDGOMek6d0od/+5v6zpxRvLdXKRMnyhMIKOPaa5WWnZ3sKQKjjkgBSWaMUV9np97fu1ddBw+q75//1EBPjzQwIFdqqlIyMuSZNk1TiouV/a//KvfEiXK5XMmeNjAqiBSQRMYYnTtxQm3/9V/68J13pIv+3rPp71d/JKL+SETnWloUbWrS9O99T+m5uUmaMTC6+EwKSKLzJ0+q7T//Ux/+9a+XBOoSxuifr7+u0//93+o9c2Z0JggkGZECkqSvs1Ond+zQh3//+9BPGhhQ54ED+uCPf1S8t3fkJgdYgkgBSWDicZ2pq1Pk0KHPfgd18bl9fXrv+efVc+rUCM0OsAeRApKgv6tLnfv3X9FjhF94QWPwv1kKDAuRApJgIBbT+XffvaLHOHf8+NWZDGAxIgWMMmOMokePJnsawJhApIAkCP/P/yR7CsCYQKQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1hhWpqqoqLVq0SJMnT1ZeXp5uv/12tbS0JIzp6elRRUWFcnNzlZmZqbKyMrW3tyeMaW1tVWlpqSZNmqS8vDxt2rRJ/f39V341AIBxZViRqqurU0VFhd544w3V1NSor69Py5cv17lz55wx9913n15++WXt3LlTdXV1On36tO644w7n+MDAgEpLS9Xb26uDBw/queee0/bt27V58+ard1UAgHHBZYwxl3vy+++/r7y8PNXV1enLX/6yIpGIpk6dqh07dujb3/62JOnEiRO67rrrFAqFtHTpUu3du1ff+MY3dPr0afn9fknS008/rR/96Ed6//33lZ6e/pnPG41G5fP5FIlE5PV6L3f6QFIYY9S0apX6Ojuv6HHSp07VnGeekcvlukozA0bPUF/Hr+gzqUgkIknKycmRJDU0NKivr0/FxcXOmNmzZ6uwsFChUEiSFAqFNHfuXCdQklRSUqJoNKrm5uZBnycWiykajSZsAIDx77IjFY/HtWHDBt14442aM2eOJCkcDis9PV1ZWVkJY/1+v8LhsDPm44G6cPzCscFUVVXJ5/M5W0FBweVOGwAwhlx2pCoqKnTs2DFVV1dfzfkMqrKyUpFIxNna2tpG/DkBAMmXejknrV27Vrt379aBAwc0ffp0Z38gEFBvb6+6uroS3k21t7crEAg4Yw4dOpTweBfu/rsw5mIej0cej+dypgoAGMOG9U7KGKO1a9dq165d2rdvn2bOnJlwfMGCBUpLS1Ntba2zr6WlRa2trQoGg5KkYDCopqYmdXR0OGNqamrk9XpVVFR0JdcCABhnhvVOqqKiQjt27NBLL72kyZMnO58h+Xw+TZw4UT6fT6tXr9bGjRuVk5Mjr9erdevWKRgMaunSpZKk5cuXq6ioSCtXrtTWrVsVDof1wAMPqKKigndLAIAE
w4rUtm3bJElf+cpXEvY/++yz+t73vidJeuyxx+R2u1VWVqZYLKaSkhI99dRTztiUlBTt3r1ba9asUTAYVEZGhsrLy/XII49c2ZUAAMadK/o9qWTh96QwlvF7UsAo/Z4UAAAjiUgBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsFZqsicAINGpc+d0tLNT3X19mjphgoJTpyojLS3Z0wKSgkgBljDG6OTZs3ro6FG9e/asegYG5E1L05zsbP37okVKc/ODD3z+8F0PWOLvZ8/q7j//WW9HIjo/MCAjKdLXpz93dGh9fb3O9PQke4rAqCNSgCV+3tysSF/foMcOffCBak6fHuUZAclHpIAkcPEZEzAkRApIgms2b072FIAxgUgBo8zlcsmVeuk9S6UFBUpzuQY9Z0Zmpubl5Iz01ADrECkgCdJ8PuX8278l7CvJz9dD11+vCSkpzr+YKS6Xcj0e/ceiRSrKykoY7y8rG53JAknELehAErgnTlR2MKjIkSMa6O6W9NE7rJL8fE2fNEm7/+//dKanRzMyM3XnzJnK9XgSzvdMm6bsYFCuT3jnBYwXRApIApfLJe/112vqLbeo/X//V2ZgwNk/Jztbc7KzP/Hc1Oxs5a9cqVSvd7SmCyQNP+4DksTt8cj/rW8p56tfHfQzqsGkZGZq2p13KmvxYrlSUkZ4hkDy8U4KSKKUSZM0/fvfV1pOjjr371dvR8eg41wpKZpQUPBR1L7yFX7Mh88NIgUkkcvlUmpGhqZ9+9vyzpunfx48qLPNzYqFw4rHYkrJzNSE6dPlW7hQvoULNbGwkEDhc4VIARZwezzKnDNHGf/yLxo4f16mv18mHpcrJUXutDS5J02Se4g/EgTGE77rAUu4XC65PB65L7qTD/g848YJAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaw0rUtu2bdO8efPk9Xrl9XoVDAa1d+9e53hPT48qKiqUm5urzMxMlZWVqb29PeExWltbVVpaqkmTJikvL0+bNm1Sf3//1bkaAMC4MqxITZ8+XVu2bFFDQ4OOHDmir371q7rtttvU3NwsSbrvvvv08ssva+fOnaqrq9Pp06d1xx13OOcPDAyotLRUvb29OnjwoJ577jlt375dmzdvvrpXBQAYH8wVys7ONs8884zp6uoyaWlpZufOnc6xt99+20gyoVDIGGPMnj17jNvtNuFw2Bmzbds24/V6TSwWG/JzRiIRI8lEIpErnT4AIAmG+jp+2Z9JDQwMqLq6WufOnVMwGFRDQ4P6+vpUXFzsjJk9e7YKCwsVCoUkSaFQSHPnzpXf73fGlJSUKBqNOu/GBhOLxRSNRhM2AMD4N+xINTU1KTMzUx6PR/fee6927dqloqIihcNhpaenKysrK2G83+9XOByWJIXD4YRAXTh+4dgnqaqqks/nc7aCgoLhThsAMAYNO1LXXnutGhsbVV9frzVr1qi8vFzHjx8fibk5KisrFYlEnK2trW1Enw8AYIfU4Z6Qnp6ua665RpK0YMECHT58WI8//rjuvPNO9fb2qqurK+HdVHt7uwKBgCQpEAjo0KFDCY934e6/C2MG4/F45PF4hjtVAMAYd8W/JxWPxxWLxbRgwQKlpaWptrbWOdbS0qLW1lYFg0FJUjAYVFNTkzo6OpwxNTU18nq9KioqutKpAADGmWG9k6qsrNQtt9yiwsJCdXd3a8eOHXrttdf06quvyufzafXq1dq4caNycnLk9Xq1bt06BYNBLV26VJK0fPlyFRUV
aeXKldq6davC4bAeeOABVVRU8E4JAHCJYUWqo6ND3/3ud/Xee+/J5/Np3rx5evXVV/W1r31NkvTYY4/J7XarrKxMsVhMJSUleuqpp5zzU1JStHv3bq1Zs0bBYFAZGRkqLy/XI488cnWvCgAwLriMMSbZkxiuaDQqn8+nSCQir9eb7OkAAIZpqK/j/O0+AIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANa6okht2bJFLpdLGzZscPb19PSooqJCubm5yszMVFlZmdrb2xPOa21tVWlpqSZNmqS8vDxt2rRJ/f39VzIVAMA4dNmROnz4sH71q19p3rx5Cfvvu+8+vfzyy9q5c6fq6up0+vRp3XHHHc7xgYEBlZaWqre3VwcPHtRzzz2n7du3a/PmzZd/FQCA8clchu7ubjNr1ixTU1NjbrrpJrN+/XpjjDFdXV0mLS3N7Ny50xn79ttvG0kmFAoZY4zZs2ePcbvdJhwOO2O2bdtmvF6vicViQ3r+SCRiJJlIJHI50wcAJNlQX8cv651URUWFSktLVVxcnLC/oaFBfX19Cftnz56twsJChUIhSVIoFNLcuXPl9/udMSUlJYpGo2pubh70+WKxmKLRaMIGABj/Uod7QnV1td58800dPnz4kmPhcFjp6enKyspK2O/3+xUOh50xHw/UheMXjg2mqqpKDz/88HCnCgAY44b1TqqtrU3r16/Xb37zG02YMGGk5nSJyspKRSIRZ2traxu15wYAJM+wItXQ0KCOjg7dcMMNSk1NVWpqqurq6vTEE08oNTVVfr9fvb296urqSjivvb1dgUBAkhQIBC652+/C1xfGXMzj8cjr9SZsAIDxb1iRWrZsmZqamtTY2OhsCxcu1IoVK5x/TktLU21trXNOS0uLWltbFQwGJUnBYFBNTU3q6OhwxtTU1Mjr9aqoqOgqXRYAYDwY1mdSkydP1pw5cxL2ZWRkKDc319m/evVqbdy4UTk5OfJ6vVq3bp2CwaCWLl0qSVq+fLmKioq0cuVKbd26VeFwWA888IAqKirk8Xiu0mUBAMaDYd848Vkee+wxud1ulZWVKRaLqaSkRE899ZRzPCUlRbt379aaNWsUDAaVkZGh8vJyPfLII1d7KgCAMc5ljDHJnsRwRaNR+Xw+RSIRPp8CgDFoqK/j/O0+AIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAaxEpAIC1UpM9gcthjJEkRaPRJM8EAHA5Lrx+X3g9/yRjMlJn
zpyRJBUUFCR5JgCAK9Hd3S2fz/eJx8dkpHJyciRJra2tn3pxn3fRaFQFBQVqa2uT1+tN9nSsxToNDes0NKzT0Bhj1N3drfz8/E8dNyYj5XZ/9FGaz+fjm2AIvF4v6zQErNPQsE5Dwzp9tqG8yeDGCQCAtYgUAMBaYzJSHo9HDz30kDweT7KnYjXWaWhYp6FhnYaGdbq6XOaz7v8DACBJxuQ7KQDA5wORAgBYi0gBAKxFpAAA1hqTkXryySc1Y8YMTZgwQUuWLNGhQ4eSPaVRdeDAAd16663Kz8+Xy+XSiy++mHDcGKPNmzdr2rRpmjhxooqLi/XOO+8kjOns7NSKFSvk9XqVlZWl1atX6+zZs6N4FSOrqqpKixYt0uTJk5WXl6fbb79dLS0tCWN6enpUUVGh3NxcZWZmqqysTO3t7QljWltbVVpaqkmTJikvL0+bNm1Sf3//aF7KiNq2bZvmzZvn/OJpMBjU3r17neOs0eC2bNkil8ulDRs2OPtYqxFixpjq6mqTnp5ufv3rX5vm5mZz9913m6ysLNPe3p7sqY2aPXv2mJ/85CfmhRdeMJLMrl27Eo5v2bLF+Hw+8+KLL5q//OUv5pvf/KaZOXOmOX/+vDPm5ptvNvPnzzdvvPGG+dOf/mSuueYac9ddd43ylYyckpIS8+yzz5pjx46ZxsZG8/Wvf90UFhaas2fPOmPuvfdeU1BQYGpra82RI0fM0qVLzZe+9CXneH9/v5kzZ44pLi42R48eNXv27DFTpkwxlZWVybikEfH73//e/OEPfzB//etfTUtLi/nxj39s0tLSzLFjx4wxrNFgDh06ZGbMmGHmzZtn1q9f7+xnrUbGmIvU4sWLTUVFhfP1wMCAyc/PN1VVVUmcVfJcHKl4PG4CgYB59NFHnX1dXV3G4/GY559/3hhjzPHjx40kc/jwYWfM3r17jcvlMqdOnRq1uY+mjo4OI8nU1dUZYz5ak7S0NLNz505nzNtvv20kmVAoZIz56P8MuN1uEw6HnTHbtm0zXq/XxGKx0b2AUZSdnW2eeeYZ1mgQ3d3dZtasWaampsbcdNNNTqRYq5Ezpn7c19vbq4aGBhUXFzv73G63iouLFQqFkjgze5w8eVLhcDhhjXw+n5YsWeKsUSgUUlZWlhYuXOiMKS4ultvtVn19/ajPeTREIhFJ//+PEzc0NKivry9hnWbPnq3CwsKEdZo7d678fr8zpqSkRNFoVM3NzaM4+9ExMDCg6upqnTt3TsFgkDUaREVFhUpLSxPWROL7aSSNqT8w+8EHH2hgYCDhf2RJ8vv9OnHiRJJmZZdwOCxJg67RhWPhcFh5eXkJx1NTU5WTk+OMGU/i8bg2bNigG2+8UXPmzJH00Rqkp6crKysrYezF6zTYOl44Nl40NTUpGAyqp6dHmZmZ2rVrl4qKitTY2MgafUx1dbXefPNNHT58+JJjfD+NnDEVKeByVFRU6NixY3r99deTPRUrXXvttWpsbFQkEtHvfvc7lZeXq66uLtnTskpbW5vWr1+vmpoaTZgwIdnT+VwZUz/umzJlilJSUi65Y6a9vV2BQCBJs7LLhXX4tDUKBALq6OhION7f36/Ozs5xt45r167V7t27tX//fk2fPt3ZHwgE1Nvbq66uroTxF6/TYOt44dh4kZ6ermuuuUYLFixQVVWV5s+fr8cff5w1+piGhgZ1dHTohhtuUGpqqlJTU1VXV6cnnnhCqamp8vv9rNUIGVORSk9P14IFC1RbW+vsi8fjqq2tVTAYTOLM7DFz5kwFAoGENYpGo6qvr3fWKBgMqqurSw0NDc6Yffv2KR6Pa8mSJaM+55FgjNHatWu1a9cu7du3TzNnzkw4vmDBAqWlpSWsU0tLi1pbWxPWqampKSHoNTU18nq9KioqGp0LSYJ4PK5YLMYafcyyZcvU1NSkxsZGZ1u4cKFWrFjh/DNrNUKSfefGcFVXVxuPx2O2b99ujh8/bu655x6TlZWVcMfMeNfd3W2O
Hj1qjh49aiSZn/3sZ+bo0aPmH//4hzHmo1vQs7KyzEsvvWTeeustc9tttw16C/r1119v6uvrzeuvv25mzZo1rm5BX7NmjfH5fOa1114z7733nrN9+OGHzph7773XFBYWmn379pkjR46YYDBogsGgc/zCLcPLly83jY2N5pVXXjFTp04dV7cM33///aaurs6cPHnSvPXWW+b+++83LpfL/PGPfzTGsEaf5uN39xnDWo2UMRcpY4z5xS9+YQoLC016erpZvHixeeONN5I9pVG1f/9+I+mSrby83Bjz0W3oDz74oPH7/cbj8Zhly5aZlpaWhMc4c+aMueuuu0xmZqbxer1m1apVpru7OwlXMzIGWx9J5tlnn3XGnD9/3vzgBz8w2dnZZtKkSeZb3/qWee+99xIe59133zW33HKLmThxopkyZYr54Q9/aPr6+kb5akbO97//ffOFL3zBpKenm6lTp5ply5Y5gTKGNfo0F0eKtRoZ/Kc6AADWGlOfSQEAPl+IFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsNb/AyO+I4/63Da7AAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from matplotlib import pyplot as plt\n",
    "\n",
    "%matplotlib inline\n",
    "\n",
    "\n",
    "# Display the current frame of the environment\n",
    "def show():\n",
    "    \"\"\"Render the global env and draw the resulting frame with matplotlib.\"\"\"\n",
    "    frame = env.render()\n",
    "    plt.imshow(frame)\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "# Draw the environment's current frame.\n",
    "show()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 网络定义\n",
    "- 动作空间是 [-2, 2] 区间内的连续值。\n",
    "- 因此无法逐个枚举每个动作的概率，而是让网络预测动作所服从的正态分布。\n",
    "- 所以这里 model 有两个输出，分别是该正态分布的均值 mu 和标准差 std。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((tensor([[0.9401],\n",
       "          [0.1030]], grad_fn=<MulBackward0>),\n",
       "  tensor([[0.7178],\n",
       "          [0.6060]], grad_fn=<SoftplusBackward0>)),\n",
       " tensor([[-0.1871],\n",
       "         [-0.4946]], grad_fn=<AddmmBackward0>))"
      ]
     },
     "execution_count": 76,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "#定义模型\n",
    "class Model(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.fc_statu = torch.nn.Sequential(\n",
    "            torch.nn.Linear(3, 128),\n",
    "            torch.nn.ReLU(),\n",
    "        )\n",
    "\n",
    "        self.fc_mu = torch.nn.Sequential(\n",
    "            torch.nn.Linear(128, 1),\n",
    "            torch.nn.Tanh(), #这里经过Tanh输出是-1，1\n",
    "        )\n",
    "\n",
    "        self.fc_std = torch.nn.Sequential(\n",
    "            torch.nn.Linear(128, 1),\n",
    "            torch.nn.Softplus(),\n",
    "        )\n",
    "\n",
    "    def forward(self, state):\n",
    "        state = self.fc_statu(state)\n",
    "\n",
    "        mu = self.fc_mu(state) * 2.0 #*2代表把[-1~1] --> [-2~2]\n",
    "        std = self.fc_std(state)\n",
    "\n",
    "        return mu, std\n",
    "\n",
    "\n",
    "# Policy network: returns (mu, std) for a batch of states.\n",
    "model = Model()\n",
    "\n",
    "# Separate network mapping a 3-feature state to a single scalar\n",
    "# (presumably the TD state-value estimate, going by the name -- TODO confirm).\n",
    "model_td = torch.nn.Sequential(\n",
    "    torch.nn.Linear(3, 128),\n",
    "    torch.nn.ReLU(),\n",
    "    torch.nn.Linear(128, 1),\n",
    ")\n",
    "\n",
    "# Smoke test: run both networks on random batches of two states.\n",
    "model(torch.randn(2, 3)), model_td(torch.randn(2, 3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-0.029393106698989868"
      ]
     },
     "execution_count": 77,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\n",
    "def get_action(state):\n",
    "    \"\"\"Sample one action for a single state from the policy's Gaussian.\n",
    "\n",
    "    Args:\n",
    "        state: sequence or array of 3 numbers; it is reshaped to (1, 3).\n",
    "\n",
    "    Returns:\n",
    "        The sampled action as a Python float.\n",
    "    \"\"\"\n",
    "    state = torch.as_tensor(state, dtype=torch.float32).reshape(1, 3)\n",
    "\n",
    "    # Acting needs no gradients, so skip building the autograd graph.\n",
    "    with torch.no_grad():\n",
    "        mu, std = model(state)\n",
    "\n",
    "    # Draw the action from Normal(mu, std) predicted by the global model.\n",
    "    return torch.distributions.Normal(mu, std).sample().item()\n",
    "\n",
    "\n",
    "# Smoke test with a dummy 3-element state.\n",
    "get_action([1, 2, 3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[-0.9445,  0.3285, -0.1006],\n",
       "         [-0.9496,  0.3135,  0.3173],\n",
       "         [-0.9581,  0.2865,  0.5643],\n",
       "         [-0.9689,  0.2473,  0.8149],\n",
       "         [-0.9791,  0.2035,  0.8981],\n",
       "         [-0.9890,  0.1481,  1.1271],\n",
       "         [-0.9958,  0.0914,  1.1409],\n",
       "         [-0.9995,  0.0319,  1.1931],\n",
       "         [-0.9997, -0.0258,  1.1540],\n",
       "         [-0.9964, -0.0847,  1.1812],\n",
       "         [-0.9884, -0.1518,  1.3501],\n",
       "         [-0.9764, -0.2160,  1.3064],\n",
       "         [-0.9604, -0.2786,  1.2932],\n",
       "         [-0.9444, -0.3288,  1.0548],\n",
       "         [-0.9294, -0.3692,  0.8606],\n",
       "         [-0.9166, -0.3998,  0.6639],\n",
       "         [-0.9042, -0.4271,  0.6006],\n",
       "         [-0.8957, -0.4446,  0.3886],\n",
       "         [-0.8959, -0.4443, -0.0066],\n",
       "         [-0.9046, -0.4262, -0.4020],\n",
       "         [-0.9182, -0.3960, -0.6626],\n",
       "         [-0.9338, -0.3578, -0.8257],\n",
       "         [-0.9499, -0.3124, -0.9627],\n",
       "         [-0.9659, -0.2587, -1.1211],\n",
       "         [-0.9803, -0.1977, -1.2540],\n",
       "         [-0.9910, -0.1338, -1.2972],\n",
       "         [-0.9980, -0.0637, -1.4077],\n",
       "         [-1.0000,  0.0038, -1.3519],\n",
       "         [-0.9982,  0.0594, -1.1127],\n",
       "         [-0.9938,  0.1108, -1.0307],\n",
       "         [-0.9875,  0.1574, -0.9408],\n",
       "         [-0.9820,  0.1890, -0.6419],\n",
       "         [-0.9778,  0.2095, -0.4193],\n",
       "         [-0.9764,  0.2162, -0.1355],\n",
       "         [-0.9776,  0.2104,  0.1183],\n",
       "         [-0.9815,  0.1912,  0.3912],\n",
       "         [-0.9851,  0.1719,  0.3935],\n",
       "         [-0.9892,  0.1467,  0.5096],\n",
       "         [-0.9934,  0.1150,  0.6400],\n",
       "         [-0.9973,  0.0729,  0.8457],\n",
       "         [-0.9996,  0.0296,  0.8672],\n",
       "         [-0.9998, -0.0191,  0.9743],\n",
       "         [-0.9978, -0.0666,  0.9512],\n",
       "         [-0.9933, -0.1155,  0.9814],\n",
       "         [-0.9856, -0.1689,  1.0794],\n",
       "         [-0.9765, -0.2155,  0.9490],\n",
       "         [-0.9671, -0.2545,  0.8042],\n",
       "         [-0.9569, -0.2903,  0.7441],\n",
       "         [-0.9492, -0.3147,  0.5113],\n",
       "         [-0.9438, -0.3306,  0.3351],\n",
       "         [-0.9410, -0.3384,  0.1673],\n",
       "         [-0.9430, -0.3328, -0.1202],\n",
       "         [-0.9462, -0.3235, -0.1976],\n",
       "         [-0.9547, -0.2975, -0.5455],\n",
       "         [-0.9650, -0.2623, -0.7342],\n",
       "         [-0.9760, -0.2179, -0.9147],\n",
       "         [-0.9866, -0.1633, -1.1121],\n",
       "         [-0.9938, -0.1113, -1.0502],\n",
       "         [-0.9983, -0.0590, -1.0510],\n",
       "         [-1.0000, -0.0068, -1.0430],\n",
       "         [-0.9991,  0.0433, -1.0025],\n",
       "         [-0.9960,  0.0897, -0.9311],\n",
       "         [-0.9909,  0.1344, -0.9001],\n",
       "         [-0.9848,  0.1739, -0.7984],\n",
       "         [-0.9789,  0.2043, -0.6198],\n",
       "         [-0.9740,  0.2265, -0.4544],\n",
       "         [-0.9707,  0.2404, -0.2861],\n",
       "         [-0.9683,  0.2497, -0.1912],\n",
       "         [-0.9719,  0.2353,  0.2960],\n",
       "         [-0.9769,  0.2136,  0.4451],\n",
       "         [-0.9827,  0.1853,  0.5780],\n",
       "         [-0.9884,  0.1522,  0.6721],\n",
       "         [-0.9922,  0.1243,  0.5630],\n",
       "         [-0.9959,  0.0906,  0.6783],\n",
       "         [-0.9982,  0.0598,  0.6181],\n",
       "         [-0.9996,  0.0266,  0.6634],\n",
       "         [-1.0000, -0.0062,  0.6562],\n",
       "         [-0.9994, -0.0346,  0.5696],\n",
       "         [-0.9979, -0.0654,  0.6156],\n",
       "         [-0.9955, -0.0949,  0.5919],\n",
       "         [-0.9930, -0.1180,  0.4659],\n",
       "         [-0.9899, -0.1415,  0.4726],\n",
       "         [-0.9863, -0.1652,  0.4800],\n",
       "         [-0.9805, -0.1966,  0.6384],\n",
       "         [-0.9731, -0.2304,  0.6932],\n",
       "         [-0.9665, -0.2568,  0.5436],\n",
       "         [-0.9595, -0.2817,  0.5172],\n",
       "         [-0.9550, -0.2965,  0.3097],\n",
       "         [-0.9534, -0.3018,  0.1105],\n",
       "         [-0.9536, -0.3011, -0.0136],\n",
       "         [-0.9558, -0.2940, -0.1494],\n",
       "         [-0.9615, -0.2749, -0.3985],\n",
       "         [-0.9678, -0.2518, -0.4797],\n",
       "         [-0.9784, -0.2065, -0.9304],\n",
       "         [-0.9884, -0.1518, -1.1123],\n",
       "         [-0.9953, -0.0965, -1.1151],\n",
       "         [-0.9991, -0.0419, -1.0946],\n",
       "         [-0.9999,  0.0155, -1.1482],\n",
       "         [-0.9978,  0.0662, -1.0148],\n",
       "         [-0.9935,  0.1138, -0.9562],\n",
       "         [-0.9885,  0.1514, -0.7588],\n",
       "         [-0.9834,  0.1817, -0.6135],\n",
       "         [-0.9782,  0.2078, -0.5329],\n",
       "         [-0.9744,  0.2248, -0.3484],\n",
       "         [-0.9736,  0.2283, -0.0707],\n",
       "         [-0.9743,  0.2251,  0.0654],\n",
       "         [-0.9779,  0.2090,  0.3289],\n",
       "         [-0.9828,  0.1848,  0.4936],\n",
       "         [-0.9898,  0.1426,  0.8569],\n",
       "         [-0.9964,  0.0846,  1.1669],\n",
       "         [-0.9998,  0.0193,  1.3086],\n",
       "         [-0.9990, -0.0437,  1.2605],\n",
       "         [-0.9941, -0.1084,  1.2969],\n",
       "         [-0.9845, -0.1755,  1.3568],\n",
       "         [-0.9692, -0.2462,  1.4474],\n",
       "         [-0.9537, -0.3007,  1.1327],\n",
       "         [-0.9377, -0.3475,  0.9888],\n",
       "         [-0.9254, -0.3791,  0.6786],\n",
       "         [-0.9189, -0.3944,  0.3315],\n",
       "         [-0.9163, -0.4005,  0.1333],\n",
       "         [-0.9166, -0.3999, -0.0126],\n",
       "         [-0.9223, -0.3865, -0.2928],\n",
       "         [-0.9331, -0.3595, -0.5801],\n",
       "         [-0.9478, -0.3190, -0.8624],\n",
       "         [-0.9636, -0.2672, -1.0841],\n",
       "         [-0.9795, -0.2017, -1.3479],\n",
       "         [-0.9915, -0.1301, -1.4519],\n",
       "         [-0.9983, -0.0589, -1.4315],\n",
       "         [-1.0000,  0.0067, -1.3121],\n",
       "         [-0.9973,  0.0729, -1.3251],\n",
       "         [-0.9911,  0.1328, -1.2051],\n",
       "         [-0.9827,  0.1852, -1.0620],\n",
       "         [-0.9720,  0.2348, -1.0138],\n",
       "         [-0.9637,  0.2668, -0.6619],\n",
       "         [-0.9576,  0.2883, -0.4462],\n",
       "         [-0.9556,  0.2946, -0.1321],\n",
       "         [-0.9567,  0.2912,  0.0704],\n",
       "         [-0.9624,  0.2716,  0.4096],\n",
       "         [-0.9687,  0.2481,  0.4859],\n",
       "         [-0.9769,  0.2138,  0.7046],\n",
       "         [-0.9857,  0.1683,  0.9285],\n",
       "         [-0.9939,  0.1100,  1.1767],\n",
       "         [-0.9989,  0.0477,  1.2504],\n",
       "         [-0.9995, -0.0307,  1.5674],\n",
       "         [-0.9936, -0.1133,  1.6577],\n",
       "         [-0.9816, -0.1910,  1.5718],\n",
       "         [-0.9647, -0.2635,  1.4908],\n",
       "         [-0.9476, -0.3195,  1.1697],\n",
       "         [-0.9320, -0.3624,  0.9128],\n",
       "         [-0.9212, -0.3891,  0.5761],\n",
       "         [-0.9150, -0.4035,  0.3138],\n",
       "         [-0.9148, -0.4039,  0.0094],\n",
       "         [-0.9218, -0.3877, -0.3520],\n",
       "         [-0.9344, -0.3563, -0.6762],\n",
       "         [-0.9482, -0.3176, -0.8222],\n",
       "         [-0.9624, -0.2717, -0.9614],\n",
       "         [-0.9766, -0.2149, -1.1726],\n",
       "         [-0.9880, -0.1545, -1.2292],\n",
       "         [-0.9958, -0.0911, -1.2764],\n",
       "         [-0.9996, -0.0269, -1.2863],\n",
       "         [-0.9995,  0.0318, -1.1744],\n",
       "         [-0.9959,  0.0908, -1.1836],\n",
       "         [-0.9895,  0.1442, -1.0758],\n",
       "         [-0.9821,  0.1881, -0.8904],\n",
       "         [-0.9770,  0.2132, -0.5109],\n",
       "         [-0.9737,  0.2278, -0.3007],\n",
       "         [-0.9724,  0.2335, -0.1157],\n",
       "         [-0.9738,  0.2274,  0.1248],\n",
       "         [-0.9774,  0.2116,  0.3237],\n",
       "         [-0.9810,  0.1938,  0.3627],\n",
       "         [-0.9862,  0.1658,  0.5707],\n",
       "         [-0.9909,  0.1345,  0.6330],\n",
       "         [-0.9952,  0.0974,  0.7469],\n",
       "         [-0.9985,  0.0544,  0.8624],\n",
       "         [-1.0000,  0.0047,  0.9941],\n",
       "         [-0.9991, -0.0417,  0.9281],\n",
       "         [-0.9968, -0.0801,  0.7694],\n",
       "         [-0.9915, -0.1303,  1.0093],\n",
       "         [-0.9834, -0.1816,  1.0393],\n",
       "         [-0.9725, -0.2329,  1.0492],\n",
       "         [-0.9598, -0.2807,  0.9886],\n",
       "         [-0.9510, -0.3093,  0.5987],\n",
       "         [-0.9435, -0.3312,  0.4641],\n",
       "         [-0.9416, -0.3368,  0.1171],\n",
       "         [-0.9443, -0.3291, -0.1618],\n",
       "         [-0.9514, -0.3078, -0.4494],\n",
       "         [-0.9627, -0.2704, -0.7817],\n",
       "         [-0.9750, -0.2224, -0.9919],\n",
       "         [-0.9857, -0.1687, -1.0938],\n",
       "         [-0.9941, -0.1083, -1.2200],\n",
       "         [-0.9992, -0.0402, -1.3669],\n",
       "         [-0.9997,  0.0235, -1.2750],\n",
       "         [-0.9964,  0.0844, -1.2198],\n",
       "         [-0.9896,  0.1438, -1.1948],\n",
       "         [-0.9798,  0.2000, -1.1412],\n",
       "         [-0.9709,  0.2396, -0.8125],\n",
       "         [-0.9624,  0.2716, -0.6610],\n",
       "         [-0.9584,  0.2855, -0.2898],\n",
       "         [-0.9564,  0.2921, -0.1387],\n",
       "         [-0.9567,  0.2911,  0.0214]]),\n",
       " tensor([[ -7.8810],\n",
       "         [ -7.9781],\n",
       "         [ -8.1600],\n",
       "         [ -8.4289],\n",
       "         [ -8.7047],\n",
       "         [ -9.0854],\n",
       "         [ -9.4328],\n",
       "         [ -9.8126],\n",
       "         [ -9.8415],\n",
       "         [ -9.4856],\n",
       "         [ -9.1181],\n",
       "         [ -8.7210],\n",
       "         [ -8.3426],\n",
       "         [ -7.9879],\n",
       "         [ -7.7112],\n",
       "         [ -7.5011],\n",
       "         [ -7.3281],\n",
       "         [ -7.2022],\n",
       "         [ -7.1889],\n",
       "         [ -7.3132],\n",
       "         [ -7.5217],\n",
       "         [ -7.7734],\n",
       "         [ -8.0669],\n",
       "         [ -8.4196],\n",
       "         [ -8.8165],\n",
       "         [ -9.2130],\n",
       "         [ -9.6716],\n",
       "         [-10.0308],\n",
       "         [ -9.6234],\n",
       "         [ -9.2908],\n",
       "         [ -8.9916],\n",
       "         [ -8.7526],\n",
       "         [ -8.6061],\n",
       "         [ -8.5504],\n",
       "         [ -8.5848],\n",
       "         [ -8.7139],\n",
       "         [ -8.8297],\n",
       "         [ -8.9921],\n",
       "         [ -9.2004],\n",
       "         [ -9.4881],\n",
       "         [ -9.7601],\n",
       "         [ -9.8448],\n",
       "         [ -9.5459],\n",
       "         [ -9.2536],\n",
       "         [ -8.9486],\n",
       "         [ -8.6424],\n",
       "         [ -8.3841],\n",
       "         [ -8.1609],\n",
       "         [ -7.9869],\n",
       "         [ -7.8778],\n",
       "         [ -7.8223],\n",
       "         [ -7.8559],\n",
       "         [ -7.9130],\n",
       "         [ -8.0925],\n",
       "         [ -8.3265],\n",
       "         [ -8.6215],\n",
       "         [ -8.9910],\n",
       "         [ -9.2919],\n",
       "         [ -9.6131],\n",
       "         [ -9.9356],\n",
       "         [ -9.7000],\n",
       "         [ -9.3999],\n",
       "         [ -9.1215],\n",
       "         [ -8.8659],\n",
       "         [ -8.6575],\n",
       "         [ -8.5069],\n",
       "         [ -8.4116],\n",
       "         [ -8.3554],\n",
       "         [ -8.4423],\n",
       "         [ -8.5831],\n",
       "         [ -8.7667],\n",
       "         [ -8.9804],\n",
       "         [ -9.1337],\n",
       "         [ -9.3545],\n",
       "         [ -9.5355],\n",
       "         [ -9.7469],\n",
       "         [ -9.8743],\n",
       "         [ -9.6858],\n",
       "         [ -9.5007],\n",
       "         [ -9.3168],\n",
       "         [ -9.1623],\n",
       "         [ -9.0208],\n",
       "         [ -8.8811],\n",
       "         [ -8.7081],\n",
       "         [ -8.5108],\n",
       "         [ -8.3360],\n",
       "         [ -8.1836],\n",
       "         [ -8.0783],\n",
       "         [ -8.0390],\n",
       "         [ -8.0416],\n",
       "         [ -8.0859],\n",
       "         [ -8.2139],\n",
       "         [ -8.3613],\n",
       "         [ -8.6926],\n",
       "         [ -9.0596],\n",
       "         [ -9.3965],\n",
       "         [ -9.7280],\n",
       "         [ -9.9049],\n",
       "         [ -9.5607],\n",
       "         [ -9.2579],\n",
       "         [ -8.9953],\n",
       "         [ -8.7929],\n",
       "         [ -8.6266],\n",
       "         [ -8.5090],\n",
       "         [ -8.4763],\n",
       "         [ -8.4956],\n",
       "         [ -8.6017],\n",
       "         [ -8.7628],\n",
       "         [ -9.0666],\n",
       "         [ -9.4810],\n",
       "         [ -9.9203],\n",
       "         [ -9.7557],\n",
       "         [ -9.3681],\n",
       "         [ -8.9784],\n",
       "         [ -8.5785],\n",
       "         [ -8.1723],\n",
       "         [ -7.8636],\n",
       "         [ -7.6240],\n",
       "         [ -7.4982],\n",
       "         [ -7.4532],\n",
       "         [ -7.4538],\n",
       "         [ -7.5425],\n",
       "         [ -7.7277],\n",
       "         [ -8.0095],\n",
       "         [ -8.3611],\n",
       "         [ -8.8167],\n",
       "         [ -9.2783],\n",
       "         [ -9.7091],\n",
       "         [ -9.9997],\n",
       "         [ -9.5923],\n",
       "         [ -9.1956],\n",
       "         [ -8.8468],\n",
       "         [ -8.5408],\n",
       "         [ -8.2893],\n",
       "         [ -8.1382],\n",
       "         [ -8.0820],\n",
       "         [ -8.1014],\n",
       "         [ -8.2348],\n",
       "         [ -8.3808],\n",
       "         [ -8.6119],\n",
       "         [ -8.9228],\n",
       "         [ -9.3276],\n",
       "         [ -9.7320],\n",
       "         [ -9.9241],\n",
       "         [ -9.4439],\n",
       "         [ -8.9465],\n",
       "         [ -8.4881],\n",
       "         [ -8.0691],\n",
       "         [ -7.7608],\n",
       "         [ -7.5517],\n",
       "         [ -7.4426],\n",
       "         [ -7.4303],\n",
       "         [ -7.5389],\n",
       "         [ -7.7595],\n",
       "         [ -8.0112],\n",
       "         [ -8.3088],\n",
       "         [ -8.6939],\n",
       "         [ -9.0706],\n",
       "         [ -9.4676],\n",
       "         [ -9.8672],\n",
       "         [ -9.8090],\n",
       "         [ -9.4466],\n",
       "         [ -9.0971],\n",
       "         [ -8.7981],\n",
       "         [ -8.5923],\n",
       "         [ -8.4873],\n",
       "         [ -8.4462],\n",
       "         [ -8.4826],\n",
       "         [ -8.5866],\n",
       "         [ -8.6954],\n",
       "         [ -8.8837],\n",
       "         [ -9.0804],\n",
       "         [ -9.3222],\n",
       "         [ -9.6054],\n",
       "         [ -9.9391],\n",
       "         [ -9.6962],\n",
       "         [ -9.4355],\n",
       "         [ -9.1685],\n",
       "         [ -8.8646],\n",
       "         [ -8.5586],\n",
       "         [ -8.2622],\n",
       "         [ -8.0292],\n",
       "         [ -7.8842],\n",
       "         [ -7.8308],\n",
       "         [ -7.8775],\n",
       "         [ -8.0220],\n",
       "         [ -8.2851],\n",
       "         [ -8.6095],\n",
       "         [ -8.9527],\n",
       "         [ -9.3484],\n",
       "         [ -9.8061],\n",
       "         [ -9.8849],\n",
       "         [ -9.4945],\n",
       "         [ -9.1268],\n",
       "         [ -8.7768],\n",
       "         [ -8.4740],\n",
       "         [ -8.2623],\n",
       "         [ -8.1430],\n",
       "         [ -8.0970],\n",
       "         [ -8.1013]]),\n",
       " tensor([[ 1.1434],\n",
       "         [ 0.0798],\n",
       "         [ 0.2376],\n",
       "         [-0.6814],\n",
       "         [ 0.5085],\n",
       "         [-0.6483],\n",
       "         [-0.1092],\n",
       "         [-0.4199],\n",
       "         [ 0.3106],\n",
       "         [ 1.5493],\n",
       "         [ 0.4677],\n",
       "         [ 0.9916],\n",
       "         [-0.1965],\n",
       "         [ 0.3494],\n",
       "         [ 0.5345],\n",
       "         [ 1.5771],\n",
       "         [ 0.7224],\n",
       "         [-0.4117],\n",
       "         [-0.4143],\n",
       "         [ 0.3941],\n",
       "         [ 0.8924],\n",
       "         [ 0.8755],\n",
       "         [ 0.5068],\n",
       "         [ 0.4073],\n",
       "         [ 0.7008],\n",
       "         [-0.0680],\n",
       "         [ 0.6902],\n",
       "         [ 1.5759],\n",
       "         [ 0.2491],\n",
       "         [ 0.0455],\n",
       "         [ 1.2058],\n",
       "         [ 0.5395],\n",
       "         [ 0.8443],\n",
       "         [ 0.6109],\n",
       "         [ 0.7675],\n",
       "         [-0.9405],\n",
       "         [-0.0856],\n",
       "         [ 0.1361],\n",
       "         [ 0.7959],\n",
       "         [-0.2205],\n",
       "         [ 0.5654],\n",
       "         [-0.0584],\n",
       "         [ 0.5347],\n",
       "         [ 1.2305],\n",
       "         [-0.0242],\n",
       "         [ 0.1117],\n",
       "         [ 0.8720],\n",
       "         [-0.1002],\n",
       "         [ 0.3987],\n",
       "         [ 0.5344],\n",
       "         [-0.2249],\n",
       "         [ 1.1480],\n",
       "         [-0.7020],\n",
       "         [ 0.2298],\n",
       "         [ 0.1078],\n",
       "         [-0.2264],\n",
       "         [ 1.2291],\n",
       "         [ 0.5514],\n",
       "         [ 0.3479],\n",
       "         [ 0.3041],\n",
       "         [ 0.2601],\n",
       "         [-0.2422],\n",
       "         [ 0.0058],\n",
       "         [ 0.3211],\n",
       "         [ 0.0812],\n",
       "         [-0.0106],\n",
       "         [-0.5694],\n",
       "         [ 2.0875],\n",
       "         [-0.1826],\n",
       "         [-0.1827],\n",
       "         [-0.2992],\n",
       "         [-1.4885],\n",
       "         [ 0.1473],\n",
       "         [-0.8541],\n",
       "         [ 0.0033],\n",
       "         [-0.1816],\n",
       "         [-0.5462],\n",
       "         [ 0.4796],\n",
       "         [ 0.1688],\n",
       "         [-0.3658],\n",
       "         [ 0.6355],\n",
       "         [ 0.7562],\n",
       "         [ 1.8820],\n",
       "         [ 1.3486],\n",
       "         [ 0.1546],\n",
       "         [ 1.1081],\n",
       "         [ 0.0247],\n",
       "         [ 0.1547],\n",
       "         [ 0.6821],\n",
       "         [ 0.6000],\n",
       "         [-0.1902],\n",
       "         [ 0.8330],\n",
       "         [-1.7456],\n",
       "         [-0.1801],\n",
       "         [ 0.7400],\n",
       "         [ 0.6191],\n",
       "         [-0.1481],\n",
       "         [ 0.8121],\n",
       "         [ 0.0596],\n",
       "         [ 0.7472],\n",
       "         [ 0.2111],\n",
       "         [-0.3708],\n",
       "         [ 0.1911],\n",
       "         [ 0.7270],\n",
       "         [-0.2338],\n",
       "         [ 0.6318],\n",
       "         [ 0.0527],\n",
       "         [ 1.4976],\n",
       "         [ 1.3541],\n",
       "         [ 0.5215],\n",
       "         [-0.4169],\n",
       "         [ 0.4612],\n",
       "         [ 0.9416],\n",
       "         [ 1.4817],\n",
       "         [-0.8666],\n",
       "         [ 0.5440],\n",
       "         [-0.3306],\n",
       "         [-0.4188],\n",
       "         [ 0.6508],\n",
       "         [ 1.0298],\n",
       "         [ 0.1317],\n",
       "         [ 0.0170],\n",
       "         [-0.0847],\n",
       "         [ 0.1174],\n",
       "         [-0.4228],\n",
       "         [ 0.3152],\n",
       "         [ 0.7864],\n",
       "         [ 1.0903],\n",
       "         [-0.1204],\n",
       "         [ 0.4357],\n",
       "         [ 0.2895],\n",
       "         [-0.6047],\n",
       "         [ 1.1721],\n",
       "         [ 0.1035],\n",
       "         [ 0.6529],\n",
       "         [-0.1227],\n",
       "         [ 0.8052],\n",
       "         [-0.8495],\n",
       "         [ 0.2177],\n",
       "         [ 0.4236],\n",
       "         [ 0.8134],\n",
       "         [-0.0588],\n",
       "         [ 1.8749],\n",
       "         [ 0.7553],\n",
       "         [-0.0064],\n",
       "         [ 0.4149],\n",
       "         [-0.8234],\n",
       "         [-0.1153],\n",
       "         [-0.4329],\n",
       "         [ 0.1966],\n",
       "         [-0.0118],\n",
       "         [-0.3898],\n",
       "         [-0.2231],\n",
       "         [ 0.8084],\n",
       "         [ 0.6607],\n",
       "         [-0.0495],\n",
       "         [ 0.6968],\n",
       "         [ 0.4576],\n",
       "         [ 0.3898],\n",
       "         [ 0.8807],\n",
       "         [-0.2207],\n",
       "         [ 0.2651],\n",
       "         [ 0.5144],\n",
       "         [ 1.5894],\n",
       "         [ 0.3358],\n",
       "         [ 0.0943],\n",
       "         [ 0.4362],\n",
       "         [ 0.1890],\n",
       "         [-0.7983],\n",
       "         [ 0.4174],\n",
       "         [-0.4131],\n",
       "         [ 0.0871],\n",
       "         [ 0.2827],\n",
       "         [ 0.6064],\n",
       "         [-0.4640],\n",
       "         [-0.8493],\n",
       "         [ 2.0223],\n",
       "         [ 0.8513],\n",
       "         [ 0.9736],\n",
       "         [ 0.7604],\n",
       "         [-1.1958],\n",
       "         [ 0.6492],\n",
       "         [-0.6571],\n",
       "         [-0.1757],\n",
       "         [-0.2715],\n",
       "         [-0.6765],\n",
       "         [-0.0488],\n",
       "         [ 0.4321],\n",
       "         [ 0.0027],\n",
       "         [-0.4376],\n",
       "         [ 0.8133],\n",
       "         [ 0.2507],\n",
       "         [-0.2556],\n",
       "         [-0.3614],\n",
       "         [ 1.1914],\n",
       "         [-0.1881],\n",
       "         [ 1.1170],\n",
       "         [-0.4199],\n",
       "         [-0.3931],\n",
       "         [-0.4036]]),\n",
       " tensor([[-0.9496,  0.3135,  0.3173],\n",
       "         [-0.9581,  0.2865,  0.5643],\n",
       "         [-0.9689,  0.2473,  0.8149],\n",
       "         [-0.9791,  0.2035,  0.8981],\n",
       "         [-0.9890,  0.1481,  1.1271],\n",
       "         [-0.9958,  0.0914,  1.1409],\n",
       "         [-0.9995,  0.0319,  1.1931],\n",
       "         [-0.9997, -0.0258,  1.1540],\n",
       "         [-0.9964, -0.0847,  1.1812],\n",
       "         [-0.9884, -0.1518,  1.3501],\n",
       "         [-0.9764, -0.2160,  1.3064],\n",
       "         [-0.9604, -0.2786,  1.2932],\n",
       "         [-0.9444, -0.3288,  1.0548],\n",
       "         [-0.9294, -0.3692,  0.8606],\n",
       "         [-0.9166, -0.3998,  0.6639],\n",
       "         [-0.9042, -0.4271,  0.6006],\n",
       "         [-0.8957, -0.4446,  0.3886],\n",
       "         [-0.8959, -0.4443, -0.0066],\n",
       "         [-0.9046, -0.4262, -0.4020],\n",
       "         [-0.9182, -0.3960, -0.6626],\n",
       "         [-0.9338, -0.3578, -0.8257],\n",
       "         [-0.9499, -0.3124, -0.9627],\n",
       "         [-0.9659, -0.2587, -1.1211],\n",
       "         [-0.9803, -0.1977, -1.2540],\n",
       "         [-0.9910, -0.1338, -1.2972],\n",
       "         [-0.9980, -0.0637, -1.4077],\n",
       "         [-1.0000,  0.0038, -1.3519],\n",
       "         [-0.9982,  0.0594, -1.1127],\n",
       "         [-0.9938,  0.1108, -1.0307],\n",
       "         [-0.9875,  0.1574, -0.9408],\n",
       "         [-0.9820,  0.1890, -0.6419],\n",
       "         [-0.9778,  0.2095, -0.4193],\n",
       "         [-0.9764,  0.2162, -0.1355],\n",
       "         [-0.9776,  0.2104,  0.1183],\n",
       "         [-0.9815,  0.1912,  0.3912],\n",
       "         [-0.9851,  0.1719,  0.3935],\n",
       "         [-0.9892,  0.1467,  0.5096],\n",
       "         [-0.9934,  0.1150,  0.6400],\n",
       "         [-0.9973,  0.0729,  0.8457],\n",
       "         [-0.9996,  0.0296,  0.8672],\n",
       "         [-0.9998, -0.0191,  0.9743],\n",
       "         [-0.9978, -0.0666,  0.9512],\n",
       "         [-0.9933, -0.1155,  0.9814],\n",
       "         [-0.9856, -0.1689,  1.0794],\n",
       "         [-0.9765, -0.2155,  0.9490],\n",
       "         [-0.9671, -0.2545,  0.8042],\n",
       "         [-0.9569, -0.2903,  0.7441],\n",
       "         [-0.9492, -0.3147,  0.5113],\n",
       "         [-0.9438, -0.3306,  0.3351],\n",
       "         [-0.9410, -0.3384,  0.1673],\n",
       "         [-0.9430, -0.3328, -0.1202],\n",
       "         [-0.9462, -0.3235, -0.1976],\n",
       "         [-0.9547, -0.2975, -0.5455],\n",
       "         [-0.9650, -0.2623, -0.7342],\n",
       "         [-0.9760, -0.2179, -0.9147],\n",
       "         [-0.9866, -0.1633, -1.1121],\n",
       "         [-0.9938, -0.1113, -1.0502],\n",
       "         [-0.9983, -0.0590, -1.0510],\n",
       "         [-1.0000, -0.0068, -1.0430],\n",
       "         [-0.9991,  0.0433, -1.0025],\n",
       "         [-0.9960,  0.0897, -0.9311],\n",
       "         [-0.9909,  0.1344, -0.9001],\n",
       "         [-0.9848,  0.1739, -0.7984],\n",
       "         [-0.9789,  0.2043, -0.6198],\n",
       "         [-0.9740,  0.2265, -0.4544],\n",
       "         [-0.9707,  0.2404, -0.2861],\n",
       "         [-0.9683,  0.2497, -0.1912],\n",
       "         [-0.9719,  0.2353,  0.2960],\n",
       "         [-0.9769,  0.2136,  0.4451],\n",
       "         [-0.9827,  0.1853,  0.5780],\n",
       "         [-0.9884,  0.1522,  0.6721],\n",
       "         [-0.9922,  0.1243,  0.5630],\n",
       "         [-0.9959,  0.0906,  0.6783],\n",
       "         [-0.9982,  0.0598,  0.6181],\n",
       "         [-0.9996,  0.0266,  0.6634],\n",
       "         [-1.0000, -0.0062,  0.6562],\n",
       "         [-0.9994, -0.0346,  0.5696],\n",
       "         [-0.9979, -0.0654,  0.6156],\n",
       "         [-0.9955, -0.0949,  0.5919],\n",
       "         [-0.9930, -0.1180,  0.4659],\n",
       "         [-0.9899, -0.1415,  0.4726],\n",
       "         [-0.9863, -0.1652,  0.4800],\n",
       "         [-0.9805, -0.1966,  0.6384],\n",
       "         [-0.9731, -0.2304,  0.6932],\n",
       "         [-0.9665, -0.2568,  0.5436],\n",
       "         [-0.9595, -0.2817,  0.5172],\n",
       "         [-0.9550, -0.2965,  0.3097],\n",
       "         [-0.9534, -0.3018,  0.1105],\n",
       "         [-0.9536, -0.3011, -0.0136],\n",
       "         [-0.9558, -0.2940, -0.1494],\n",
       "         [-0.9615, -0.2749, -0.3985],\n",
       "         [-0.9678, -0.2518, -0.4797],\n",
       "         [-0.9784, -0.2065, -0.9304],\n",
       "         [-0.9884, -0.1518, -1.1123],\n",
       "         [-0.9953, -0.0965, -1.1151],\n",
       "         [-0.9991, -0.0419, -1.0946],\n",
       "         [-0.9999,  0.0155, -1.1482],\n",
       "         [-0.9978,  0.0662, -1.0148],\n",
       "         [-0.9935,  0.1138, -0.9562],\n",
       "         [-0.9885,  0.1514, -0.7588],\n",
       "         [-0.9834,  0.1817, -0.6135],\n",
       "         [-0.9782,  0.2078, -0.5329],\n",
       "         [-0.9744,  0.2248, -0.3484],\n",
       "         [-0.9736,  0.2283, -0.0707],\n",
       "         [-0.9743,  0.2251,  0.0654],\n",
       "         [-0.9779,  0.2090,  0.3289],\n",
       "         [-0.9828,  0.1848,  0.4936],\n",
       "         [-0.9898,  0.1426,  0.8569],\n",
       "         [-0.9964,  0.0846,  1.1669],\n",
       "         [-0.9998,  0.0193,  1.3086],\n",
       "         [-0.9990, -0.0437,  1.2605],\n",
       "         [-0.9941, -0.1084,  1.2969],\n",
       "         [-0.9845, -0.1755,  1.3568],\n",
       "         [-0.9692, -0.2462,  1.4474],\n",
       "         [-0.9537, -0.3007,  1.1327],\n",
       "         [-0.9377, -0.3475,  0.9888],\n",
       "         [-0.9254, -0.3791,  0.6786],\n",
       "         [-0.9189, -0.3944,  0.3315],\n",
       "         [-0.9163, -0.4005,  0.1333],\n",
       "         [-0.9166, -0.3999, -0.0126],\n",
       "         [-0.9223, -0.3865, -0.2928],\n",
       "         [-0.9331, -0.3595, -0.5801],\n",
       "         [-0.9478, -0.3190, -0.8624],\n",
       "         [-0.9636, -0.2672, -1.0841],\n",
       "         [-0.9795, -0.2017, -1.3479],\n",
       "         [-0.9915, -0.1301, -1.4519],\n",
       "         [-0.9983, -0.0589, -1.4315],\n",
       "         [-1.0000,  0.0067, -1.3121],\n",
       "         [-0.9973,  0.0729, -1.3251],\n",
       "         [-0.9911,  0.1328, -1.2051],\n",
       "         [-0.9827,  0.1852, -1.0620],\n",
       "         [-0.9720,  0.2348, -1.0138],\n",
       "         [-0.9637,  0.2668, -0.6619],\n",
       "         [-0.9576,  0.2883, -0.4462],\n",
       "         [-0.9556,  0.2946, -0.1321],\n",
       "         [-0.9567,  0.2912,  0.0704],\n",
       "         [-0.9624,  0.2716,  0.4096],\n",
       "         [-0.9687,  0.2481,  0.4859],\n",
       "         [-0.9769,  0.2138,  0.7046],\n",
       "         [-0.9857,  0.1683,  0.9285],\n",
       "         [-0.9939,  0.1100,  1.1767],\n",
       "         [-0.9989,  0.0477,  1.2504],\n",
       "         [-0.9995, -0.0307,  1.5674],\n",
       "         [-0.9936, -0.1133,  1.6577],\n",
       "         [-0.9816, -0.1910,  1.5718],\n",
       "         [-0.9647, -0.2635,  1.4908],\n",
       "         [-0.9476, -0.3195,  1.1697],\n",
       "         [-0.9320, -0.3624,  0.9128],\n",
       "         [-0.9212, -0.3891,  0.5761],\n",
       "         [-0.9150, -0.4035,  0.3138],\n",
       "         [-0.9148, -0.4039,  0.0094],\n",
       "         [-0.9218, -0.3877, -0.3520],\n",
       "         [-0.9344, -0.3563, -0.6762],\n",
       "         [-0.9482, -0.3176, -0.8222],\n",
       "         [-0.9624, -0.2717, -0.9614],\n",
       "         [-0.9766, -0.2149, -1.1726],\n",
       "         [-0.9880, -0.1545, -1.2292],\n",
       "         [-0.9958, -0.0911, -1.2764],\n",
       "         [-0.9996, -0.0269, -1.2863],\n",
       "         [-0.9995,  0.0318, -1.1744],\n",
       "         [-0.9959,  0.0908, -1.1836],\n",
       "         [-0.9895,  0.1442, -1.0758],\n",
       "         [-0.9821,  0.1881, -0.8904],\n",
       "         [-0.9770,  0.2132, -0.5109],\n",
       "         [-0.9737,  0.2278, -0.3007],\n",
       "         [-0.9724,  0.2335, -0.1157],\n",
       "         [-0.9738,  0.2274,  0.1248],\n",
       "         [-0.9774,  0.2116,  0.3237],\n",
       "         [-0.9810,  0.1938,  0.3627],\n",
       "         [-0.9862,  0.1658,  0.5707],\n",
       "         [-0.9909,  0.1345,  0.6330],\n",
       "         [-0.9952,  0.0974,  0.7469],\n",
       "         [-0.9985,  0.0544,  0.8624],\n",
       "         [-1.0000,  0.0047,  0.9941],\n",
       "         [-0.9991, -0.0417,  0.9281],\n",
       "         [-0.9968, -0.0801,  0.7694],\n",
       "         [-0.9915, -0.1303,  1.0093],\n",
       "         [-0.9834, -0.1816,  1.0393],\n",
       "         [-0.9725, -0.2329,  1.0492],\n",
       "         [-0.9598, -0.2807,  0.9886],\n",
       "         [-0.9510, -0.3093,  0.5987],\n",
       "         [-0.9435, -0.3312,  0.4641],\n",
       "         [-0.9416, -0.3368,  0.1171],\n",
       "         [-0.9443, -0.3291, -0.1618],\n",
       "         [-0.9514, -0.3078, -0.4494],\n",
       "         [-0.9627, -0.2704, -0.7817],\n",
       "         [-0.9750, -0.2224, -0.9919],\n",
       "         [-0.9857, -0.1687, -1.0938],\n",
       "         [-0.9941, -0.1083, -1.2200],\n",
       "         [-0.9992, -0.0402, -1.3669],\n",
       "         [-0.9997,  0.0235, -1.2750],\n",
       "         [-0.9964,  0.0844, -1.2198],\n",
       "         [-0.9896,  0.1438, -1.1948],\n",
       "         [-0.9798,  0.2000, -1.1412],\n",
       "         [-0.9709,  0.2396, -0.8125],\n",
       "         [-0.9624,  0.2716, -0.6610],\n",
       "         [-0.9584,  0.2855, -0.2898],\n",
       "         [-0.9564,  0.2921, -0.1387],\n",
       "         [-0.9567,  0.2911,  0.0214],\n",
       "         [-0.9593,  0.2825,  0.1792]]),\n",
       " tensor([[0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [0],\n",
       "         [1]]))"
      ]
     },
     "execution_count": 78,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def get_data():\n",
    "    \"\"\"Roll out one complete episode and return it as a batch of tensors.\n",
    "\n",
    "    Every action comes from get_action(state) and is wrapped in a list\n",
    "    before being passed to env.step.\n",
    "\n",
    "    Returns:\n",
    "        (states, rewards, actions, next_states, overs), reshaped to\n",
    "        [b, 3], [b, 1], [b, 1], [b, 3] and [b, 1] respectively, where b is\n",
    "        the number of steps played. overs is a LongTensor, the others are\n",
    "        FloatTensors.\n",
    "    \"\"\"\n",
    "\n",
    "    def as_float(values, width):\n",
    "        #pack a python list into a float tensor with `width` columns\n",
    "        return torch.FloatTensor(values).reshape(-1, width)\n",
    "\n",
    "    #one (state, reward, action, next_state, over) tuple per step\n",
    "    transitions = []\n",
    "\n",
    "    #start a fresh episode\n",
    "    state = env.reset()\n",
    "\n",
    "    #at least one step is always played; stop once the env reports the end\n",
    "    while True:\n",
    "        #choose an action for the current state\n",
    "        action = get_action(state)\n",
    "\n",
    "        #apply the action and collect the feedback\n",
    "        next_state, reward, over, _ = env.step([action])\n",
    "\n",
    "        #remember this transition\n",
    "        transitions.append((state, reward, action, next_state, over))\n",
    "\n",
    "        #move on to the next step\n",
    "        state = next_state\n",
    "        if over:\n",
    "            break\n",
    "\n",
    "    #[b, 3]\n",
    "    states = as_float([step[0] for step in transitions], 3)\n",
    "    #[b, 1]\n",
    "    rewards = as_float([step[1] for step in transitions], 1)\n",
    "    #[b, 1]\n",
    "    actions = as_float([step[2] for step in transitions], 1)\n",
    "    #[b, 3]\n",
    "    next_states = as_float([step[3] for step in transitions], 3)\n",
    "    #[b, 1]\n",
    "    overs = torch.LongTensor([step[4] for step in transitions]).reshape(-1, 1)\n",
    "\n",
    "    return states, rewards, actions, next_states, overs\n",
    "\n",
    "\n",
    "get_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-1224.31568119985"
      ]
     },
     "execution_count": 79,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from IPython import display\n",
    "import random\n",
    "\n",
    "def test(play):\n",
    "    \"\"\"Play a single episode with get_action and return the summed reward.\n",
    "\n",
    "    Args:\n",
    "        play: when truthy, roughly one step in five clears the cell output\n",
    "            and calls show() (presumably a frame renderer defined in\n",
    "            another cell - TODO confirm).\n",
    "\n",
    "    Returns:\n",
    "        The sum of all rewards received during the episode; larger is better.\n",
    "    \"\"\"\n",
    "    #start a fresh episode\n",
    "    state = env.reset()\n",
    "\n",
    "    #running total of the rewards\n",
    "    total_reward = 0\n",
    "\n",
    "    #at least one step is always played; stop once the env reports the end\n",
    "    while True:\n",
    "        #choose an action for the current state\n",
    "        action = get_action(state)\n",
    "\n",
    "        #apply the action and accumulate the reward\n",
    "        state, reward, over, _ = env.step([action])\n",
    "        total_reward += reward\n",
    "\n",
    "        #animation with frame skipping; the random draw only happens when playing\n",
    "        if play:\n",
    "            if random.random() < 0.2:\n",
    "                display.clear_output(wait=True)\n",
    "                show()\n",
    "\n",
    "        if over:\n",
    "            break\n",
    "\n",
    "    return total_reward\n",
    "\n",
    "\n",
    "test(play=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[5.43839184, 6.7140640000000005, 7.0544, 6.24, 4.0]"
      ]
     },
     "execution_count": 80,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#优势函数\n",
    "def get_advantages(deltas):\n",
    "    advantages = []\n",
    "\n",
    "    #反向遍历deltas\n",
    "    s = 0.0\n",
    "    for delta in deltas[::-1]:\n",
    "        s = 0.9 * 0.9 * s + delta\n",
    "        advantages.append(s)\n",
    "\n",
    "    #逆序\n",
    "    advantages.reverse()\n",
    "    return advantages\n",
    "\n",
    "\n",
    "get_advantages(range(5))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 -417.788817209786\n",
      "200 -814.4478025370647\n",
      "400 -225.58088187103195\n",
      "600 -389.12576238812835\n",
      "800 -232.9130072375695\n",
      "1000 -265.50123287547024\n",
      "1200 -251.5885395787554\n",
      "1400 -215.8627665094043\n",
      "1600 -313.9082788119813\n",
      "1800 -254.42935434472673\n",
      "2000 -286.9488126255482\n",
      "2200 -185.12964854297937\n",
      "2400 -203.15024209350923\n",
      "2600 -300.55436286825454\n",
      "2800 -450.3713602416284\n"
     ]
    }
   ],
   "source": [
    "def train(epochs=3000, updates_per_episode=10, gamma=0.98, clip=0.2,\n",
    "          test_every=200):\n",
    "    \"\"\"Train the policy `model` and the value network `model_td`.\n",
    "\n",
    "    Each epoch plays one episode with get_data(), builds TD targets and\n",
    "    advantages from it once, and then runs several updates on that same\n",
    "    batch: a clipped-ratio surrogate loss for `model` and an MSE loss\n",
    "    against the TD targets for `model_td`.\n",
    "\n",
    "    Args:\n",
    "        epochs: number of episodes to play and train on.\n",
    "        updates_per_episode: number of updates performed on each episode.\n",
    "        gamma: factor applied to the next-state value in the TD target.\n",
    "        clip: the probability ratio is clamped to [1 - clip, 1 + clip].\n",
    "        test_every: every this many epochs, print the epoch index and the\n",
    "            mean of 10 test(play=False) results.\n",
    "\n",
    "    The defaults reproduce the previously hard-coded values.\n",
    "    \"\"\"\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)\n",
    "    optimizer_td = torch.optim.Adam(model_td.parameters(), lr=5e-3)\n",
    "    loss_fn = torch.nn.MSELoss()\n",
    "\n",
    "    #play N episodes and train several times on each of them\n",
    "    for epoch in range(epochs):\n",
    "        #play one episode to get the data\n",
    "        #states -> [b, 3]\n",
    "        #rewards -> [b, 1]\n",
    "        #actions -> [b, 1]\n",
    "        #next_states -> [b, 3]\n",
    "        #overs -> [b, 1]\n",
    "        states, rewards, actions, next_states, overs = get_data()\n",
    "\n",
    "        #shift and rescale the reward to make training easier\n",
    "        rewards = (rewards + 8) / 8\n",
    "\n",
    "        #everything in this block is a fixed snapshot taken before the\n",
    "        #updates below, so no gradient tracking is needed\n",
    "        with torch.no_grad():\n",
    "            #[b, 3] -> [b, 1]\n",
    "            values = model_td(states)\n",
    "\n",
    "            #[b, 3] -> [b, 1]\n",
    "            #the next-state value is dropped on steps flagged as over\n",
    "            targets = model_td(next_states) * gamma\n",
    "            targets *= (1 - overs)\n",
    "            targets += rewards\n",
    "\n",
    "            #advantages are built from the gap between target and value\n",
    "            #[b, 1]\n",
    "            deltas = (targets - values).squeeze(dim=1).tolist()\n",
    "            advantages = get_advantages(deltas)\n",
    "            advantages = torch.FloatTensor(advantages).reshape(-1, 1)\n",
    "\n",
    "            #log-probability of every taken action under the current policy\n",
    "            #[b, 3] -> [b, 1],[b, 1]\n",
    "            mu, std = model(states)\n",
    "            #[b, 1]\n",
    "            old_log_probs = torch.distributions.Normal(mu, std).log_prob(actions)\n",
    "\n",
    "        #train repeatedly on the same batch\n",
    "        for _ in range(updates_per_episode):\n",
    "            #log-probability of the same actions under the updated policy\n",
    "            #[b, 3] -> [b, 1],[b, 1]\n",
    "            mu, std = model(states)\n",
    "            #[b, 1]\n",
    "            new_log_probs = torch.distributions.Normal(mu, std).log_prob(actions)\n",
    "\n",
    "            #probability ratio, formed in log space so that a density which\n",
    "            #underflows or overflows cannot turn the division into nan\n",
    "            #[b, 1] - [b, 1] -> [b, 1]\n",
    "            ratios = (new_log_probs - old_log_probs).exp()\n",
    "\n",
    "            #clipped and unclipped surrogate, keep the smaller of the two\n",
    "            #[b, 1] * [b, 1] -> [b, 1]\n",
    "            surr1 = ratios * advantages\n",
    "            #[b, 1] * [b, 1] -> [b, 1]\n",
    "            surr2 = torch.clamp(ratios, 1 - clip, 1 + clip) * advantages\n",
    "\n",
    "            loss = -torch.min(surr1, surr2)\n",
    "            loss = loss.mean()\n",
    "\n",
    "            #recompute the values with gradients for the temporal-difference loss\n",
    "            values = model_td(states)\n",
    "            loss_td = loss_fn(values, targets)\n",
    "\n",
    "            #update the policy parameters\n",
    "            optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            #update the value network parameters\n",
    "            optimizer_td.zero_grad()\n",
    "            loss_td.backward()\n",
    "            optimizer_td.step()\n",
    "\n",
    "        if epoch % test_every == 0:\n",
    "            test_result = sum([test(play=False) for _ in range(10)]) / 10\n",
    "            print(epoch, test_result)\n",
    "\n",
    "\n",
    "train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAakAAAGiCAYAAABd6zmYAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAArSklEQVR4nO3de3hU9YH/8c9MLhNImIQEkiGSCBbL5eFS5TprV5aSElzqpaZdtVSzluqDDTwiLlvZKq4+drHgVutWsfu4q/58RPyxBSsIagQMKpFLIIrh4g1IuEwCZDO5kOvM9/eHZX5GgyRhkvmGvF/PM89jzjlz5nuOZt6eM2dOHMYYIwAALOSM9AAAADgXIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsFbEIvXUU09pyJAhiouL0+TJk7Vjx45IDQUAYKmIROqVV17RwoUL9eCDD2r37t0aN26csrOzVVFREYnhAAAs5YjEDWYnT56siRMn6o9//KMkKRgMKiMjQ/Pnz9d9993X3cMBAFgqurtfsKmpSUVFRVq8eHFomtPpVFZWlgoLC9t8TmNjoxobG0M/B4NBVVZWKiUlRQ6Ho8vHDAAIL2OMampqlJ6eLqfz3Cf1uj1Sp06dUiAQUFpaWqvpaWlpOnDgQJvPWbp0qR566KHuGB4AoBuVlZVp8ODB55zf7ZHqjMWLF2vhwoWhn/1+vzIzM1VWVia32x3BkQEAOqO6uloZGRnq16/fty7X7ZEaMGCAoqKiVF5e3mp6eXm5PB5Pm89xuVxyuVzfmO52u4kUAPRg5/vIptuv7ouNjdX48eO1adOm0LRgMKhNmzbJ6/V293AAABaLyOm+hQsXKjc3VxMmTNCkSZP0xBNPqK6uTrfffnskhgMAsFREInXTTTfp5MmTWrJkiXw+n773ve/pjTfe+MbFFACA3i0i35O6UNXV1UpMTJTf7+czKQDogdr7Ps69+wAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYq8OR2rp1q6699lqlp6fL4XDo1VdfbTXfGKMlS5Zo0KBB6tOnj7KysvTpp5+2WqayslKzZ8+W2+1WUlKS5syZo9ra2gvaEADAxafDkaqrq9O4ceP01FNPtTl/2bJlevLJJ/XMM89o+/btio+PV3Z2thoaGkLLzJ49WyUlJcrPz9f69eu1detW3XnnnZ3fCgDAxclcAElm7dq1oZ+DwaDxeDxm+fLloWlVVVXG5XKZl19+2RhjzL59+4wks3PnztAyGzduNA6Hwxw7dqxdr+v3+40k4/f7L2T4AIAIae/7eFg/kzp06JB8Pp+ysrJC0xITEzV58mQVFhZKkgoLC5WUlKQJEyaElsnKypLT6dT27dvbXG9jY6Oqq6tbPQAAF7+wRsrn80mS0tLSWk1PS0sLzfP5fEpNTW01Pzo6WsnJyaFlvm7p0qVKTEwMPTIyMsI5bACApXrE1X2LFy+W3+8PPcrKyiI9JABANwhrpDwejySpvLy81fTy8vLQPI/Ho4qKilbzW1paVFlZGVrm61wul9xud6sHAODiF9ZIDR06VB6PR5s2bQpNq66u1vbt2+X1eiVJ
Xq9XVVVVKioqCi2zefNmBYNBTZ48OZzDAQD0cNEdfUJtba0+++yz0M+HDh1ScXGxkpOTlZmZqQULFuiRRx7R5ZdfrqFDh+qBBx5Qenq6brjhBknSyJEjNXPmTN1xxx165pln1NzcrHnz5unmm29Wenp62DYMAHAR6Ohlg1u2bDGSvvHIzc01xnx5GfoDDzxg0tLSjMvlMtOnTzcHDx5stY7Tp0+bW265xSQkJBi3221uv/12U1NTE/ZLFwEAdmrv+7jDGGMi2MhOqa6uVmJiovx+P59PAUAP1N738R5xdR8AoHciUgAAaxEpAIC1iBQAwFpECgBgLSIFALAWkQIAWItIAQCsRaQAANYiUgAAa3X4BrMAwi9QV6czX3yh+iNH1FJbK4fDoZjkZPUZMkR9hgyRMyYm0kMEIoJIARFkWlpU8/HH8q1dq4bSUrXU1Mg0N0uSnHFxina7FT98uNJvukmuSy6Rw8nJD/QuRAqIkJbaWp3cuFG+//t/FWxs/Mb8YH29murr1VRerrr9+5V+663qf9VVHFWhV+F/y4AICDY2quL111W+Zk2bgfq6ppMndez//B9Vbd+uHviHC4BOI1JANzPGqLq4WOVr1ihQV9fu5zWfOqVjL7ygxmPHunB0gF2IFNDNgvX1Kn36aQXr6zv83Kbych178UUFm5q6YGSAfYgU0M1Ovf22WmprO/38uoMHVbtvXxhHBNiLSAHd7Mznn4eu4OuM5spKNVVUhHFEgL2IFADAWkQK6IGMMVzlh16BSAE9kAkEIj0EoFsQKaAnCgYljqTQCxApoAfiSAq9BZECeiATCHAkhV6BSAE9kAkGIz0EoFsQKaAn4jMp9BJECuiBTCAgEoXegEgBPRCfSaG3IFJAD8RnUugtiBTQE3EkhV6CSAE9EN+TQm9BpIAeyHB1H3oJIgX0RBxJoZcgUkAPZIJB7oKOXoFIAT0Qn0mhtyBSQA/E96TQWxApoAfiwgn0FkQK6Ik43YdegkgBPRCn+9BbECmgBzLBIDeYRa9ApIAeqHr3bgXPnIn0MIAuR6SAHsg0N/M9KfQKRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECAFiLSAEArEWkAADWIlIAAGsRKQCAtYgUAMBaRAoAYC0iBQCwFpECullU376SwxHpYQA9ApECulm/cePkjIu78BUFgzLGXPh6AIsRKaCbOZzh+bUzgUBY1gPYrEO/LUuXLtXEiRPVr18/paam6oYbbtDBgwdbLdPQ0KC8vDylpKQoISFBOTk5Ki8vb7VMaWmpZs2apb59+yo1NVWLFi1SS0vLhW8N0BOE6VSf4XcGvUCHIlVQUKC8vDx98MEHys/PV3Nzs2bMmKG6urrQMvfcc4/WrVun1atXq6CgQMePH9eNN94Ymh8IBDRr1iw1NTVp27ZteuGFF/T8889ryZIl4dsqwGIOhyMsoeJICr2Bw1zASe2TJ08qNTVVBQUFuvrqq+X3+zVw4ECtXLlSP/nJTyRJBw4c0MiRI1VYWKgpU6Zo48aN+tGPfqTjx48rLS1NkvTMM8/o17/+tU6ePKnY2Njzvm51dbUSExPl9/vldrs7O3wgIvy7d+uLZcsUPHPmgtYz4t//XX2HDfsyekAP09738Qs6Oe73+yVJycnJkqSioiI1NzcrKysrtMyIESOUmZmpwsJCSVJhYaHGjBkTCpQkZWdnq7q6WiUlJW2+TmNjo6qrq1s9gB4rXKf7OJJCL9DpSAWDQS1YsEBXXXWVRo8eLUny+XyKjY1VUlJSq2XT0tLk8/lCy3w1UGfnn53XlqVLlyoxMTH0yMjI6OywgYhzOBwKR6aIFHqDTkcqLy9PH3/8sVatWhXO8bRp8eLF8vv9oUdZWVmXvybQZTiSAtotujNPmjdvntavX6+tW7dq8ODBoekej0dNTU2qqqpqdTRVXl4uj8cTWmbHjh2t1nf26r+zy3ydy+WS
y+XqzFAB+zid4blwgqv70At0KFLGGM2fP19r167VO++8o6FDh7aaP378eMXExGjTpk3KycmRJB08eFClpaXyer2SJK/Xq9/+9reqqKhQamqqJCk/P19ut1ujRo0KxzYBVnM4HGoOBnWirk7H6+tVXl+vysZG1ba0qDEQUGZ8vLIvuUQJMTHfviKOpNALdChSeXl5Wrlypf7yl7+oX79+oc+QEhMT1adPHyUmJmrOnDlauHChkpOT5Xa7NX/+fHm9Xk2ZMkWSNGPGDI0aNUq33nqrli1bJp/Pp/vvv195eXkcLeGiZIxRc3Oz6urqtG/fPr2+apXeLijQiZoa1be0qCkYVEswqMBfl/2b1FRN9XjOGylO96E36FCkVqxYIUn6u7/7u1bTn3vuOf3jP/6jJOnxxx+X0+lUTk6OGhsblZ2draeffjq0bFRUlNavX6+77rpLXq9X8fHxys3N1cMPP3xhWwJYqLa2VsXFxSooKNCrr76qDz/8UIFAQLGSkl0uefr0UVJsrNyxseoXEyOX06nL+vVTn+jz/2oGiRR6gQv6nlSk8D0p2K6pqUk7duzQiy++qM2bN+vIkSMyxmjEiBGaMGKEMg8fVnpMjAa4XOrvcqlPVJSinE45pXZ/72nIvfcq+eqr+Z4UeqT2vo936sIJAG1rbm7WkSNHtHz5cq1bt06VlZXq37+/fvSjHyk3N1eTJk1SVHm5jv7udzJfuVNLZ3C6D70BkQLCwBijyspKrVmzRo8//rg++eQTDRo0SD/72c80Z84cTZw4UTF//YzpTG2tnA6HLjQxRAq9AZECLpAxRj6fT0uWLNHatWvl9/uVnZ2tuXPnaurUqerXr1/rU3J8TwpoNyIFXICWlhbt3r1bv/71r/XBBx9o4MCB+ud//mfdeeedcrvdcrb1Zzm4CzrQbkQK6CRjjDZt2qTFixdr7969+v73v6/77rtPP/jBD0Kn9triCNOXefmeFHoDIgV0gjFG7777ru69914dOHBAM2fO1O9+9zuNHDmy7aOnr+J0H9Bu/GVeoIMCgYC2bNmiefPm6dChQ8rNzdWKFSs0atSo8wdK7b/E/HyIFHoDIgV0gDFG+/fv1/3336/9+/crJydHDz30kDIyMtofnzD9+fgzn38elvUANiNSQDsZY1RTU6MlS5Zo586dmjp1qh5++OFWN1lulzD9Zd4zn356wesAbEekgHaqr6/X8uXL9eabb+p73/ueli1bpksvvbTD6wnX35MCegMiBbRDMBjUli1b9OKLLyo+Pl6LFi3SuHHjOvf5UphO9wG9Ab8tQDtUVlbqj3/8o44dO6af//znuvbaaxUVFdWpdXGvPaD9iBRwHsYYvfLKK3r77bc1btw4LVq0SH369On8CokU0G5ECjiPY8eO6T//8z+VkJCgX/3qVxo4cOCFrTBcX+YFegEiBXyLYDCoV155RZ9//rmuuuoqzZw5s9On+c7idB/QfkQK+BaHDh3Sa6+9pujoaM2ePVuDBg268MgQKaDdiBRwDsFgUEVFRfroo480bNgwZWdnh+coiEgB7UakgHNoaGjQhg0bVFNTo5tvvln9+/cPy3rDdoNZoBcgUsA5nDp1Sm+99ZYGDBig6667LnwrJlBAuxEp4ByKiop08uRJTZkyRampqWG74MFh4Zd5g8Ggjh49GulhAN9g328LYAFjjN544w1J0tVXX62EhITwrdzCI6lPP/1UTzzxhALcWR2WIVJAG+rq6rRr1y4NHDhQI0eOVHR0GP/0mmVHUsFgUOvWrdOGDRt0+PDhSA8HaMWu3xbAEkeOHNHp06c1aNCgjt/l/DwcklVHUydPntTWrVt1/PhxFRQUyBgT6SEBIUQKaMNnn32muro6eTwepaenh3flFh1JGWNUXFys3bt3q7a2Vlu2bFFNTU2khwWE8OfjgTYcPXpUTU1N8ng8SkxMDOu6z/enOo7V1WlPZaVqmps1MC5O3oEDFR8TE9YxnNXU1KR33nlHPp9P
xhjt2bNHn332ma644grujAErECmgDRUVFWppaZHH4wnv51HSOY+kjDE6VFurB/fs0eHaWjUEAnLHxGh0//56bOJExXTBEVhVVZVef/310Cm+L774Qh999JHGjRt3wbd/AsLBnvMOgEWqq6tljFFaWlr4V36Ov8z7RW2t7nj/fe33+1UfCMhI8jc36/2KCt29fbtONzSEfSjvvfee9u/fH/q5sbFR69evV0MXvBbQGUQKaEMwGJTT6Qz7qT7p3DeYfaKkRP7m5jbn7Th1SvnHj4d1HM3NzXr55ZcVDAZbTd+6datOnDgR1tcCOotIAefgcDgUFxfXFSsO/zo74cCBA9q1a9eXn5H9dUwOh0OVlZXavHkzV/nBCkQK+BZfP8oIC4dDfb/znfCvtwMCgYBee+011dXV6Sc/+YnGjx+vPn36aOrUqfJ4PFq9ejWn/GAFIgWcgzFGjY2NYV+vw+HQ4F/+8hvTZ2VkKOYcR1lDEhI0Njk5bGPw+Xz69NNPtWDBAv32t79Venq64uLidOedd+qxxx6TMUZ79+4N2+sBncXVfUAbnE6ngsGg/H5/l6w/JjFRydOmqXLLltC07L9+H+uRDz9UUyCgoKQoh0NJsbH694kTdenXbs2UlpPTqdc+G9/bbrtNU6ZMUU1NjRobGxUVFaXU1FT97d/+rS677DKdOXOm09sHhAuRAtrgdrvlcDhUXl7eJet39umj/l6v/Lt2KfDXL886HA5lp6drcN++Wn/0qE43NGhIQoJuGjpUKS5Xq+e7Bg1Sf6+3099lGjp0qIYOHSpJOnHihGpqahQdHa2UlBRFRUVp4sSJF7aBQJgQKaANqampio6Ols/nU0tLS9i/K+VwOOS+4goNvOYalf/5zzJ/vbGrw+HQ6P79Nfpb/nZVdP/+Sr/1VkW73Z1+7a+qqqpSVVWVXC6XBgwYwJd4YRU+kwLakJGRodjYWJWXl3fZKT+ny6W0H/9YyT/4gRztjGBUQoIG3XSTkiZNkiMMX7Y1xujUqVM6deqU+vfvrwEDBlzwOoFwIlJAG4YNG6b4+Hj5fD4dD/P3k74qqm9fDf7FL5SWk6PY1NRzLueIilKfIUOUcccdGnjNNXLGxobl9Y0xKisr0+nTpzVq1CjFdNHtl4DO4nQf0IbMzEylpKToxIkTOnr0qMaMGdMlr+NwOBQdH69BP/mJ3GPH6n+3bVNtSYkafT4FGxsVlZCguMGDlThhghInTFCfzMywno5rbGxUSUmJAoGAxo8fz6k+WIdIAW2Ij4/XxIkT9V//9V/av3+/fvjDH4b/Hn5f4XS5lDB6tOK/+10F6utlWlpkgkE5oqLkjImRs29fObvg9c+cOaNt27YpJiZGV111VdjXD1woTvcBbXA4HJo5c6Yk6d1331VdXV23vKbT5VJMUpJiBwyQKzVVsSkpina7uyRQklRaWqqSkhKNGDFCQ4YM6ZLXAC4EkQLOYcKECUpLS9MHH3ygkydPXnS3CTLG6I033lB9fb2uvvpqJXzte1iADYgUcA7JycmaOXOmTp48qfXr10d6OGHn9/u1du1a9e/fX9OmTeua+xQCF4hIAecQFxena665Rm63WytXrlRVVVWkhxQ2xhi9/fbb+vzzzzV27FhdeeWVXDQBKxEp4BycTqeuuOIKjRs3Tp988onefPPNi+KUnzFGVVVVWrdunRoaGjRt2jRdcsklkR4W0CYiBXyLIUOG6LrrrlNLS4tefPFFHT9+vMeHyhij999/X2+99ZY8Ho9mz57dpVcuAheCSAHfwul06qc//akuu+wybdu2TW+++aYCf72FUU9VW1urZ599VhUVFfrZz34WuocfYCMiBZzHJZdcojvvvFO1tbV6+umnVVFREekhdZoxRuvXr1d+fr6GDx+un//855EeEvCtiBRwHg6HQzfffLOysrL04Ycfavny5aqvr4/0sDrMGKOSkhItXbpU0dHRysvL02WXXRbpYQHfikgB7ZCcnKz58+frkksu
0UsvvaR169b1uNN+p0+f1hNPPKHPPvtM06dP149//GPFxMRwVR+sRqSAdnA6nZo2bZpyc3NVV1en5cuXq7i4uMdcRNHc3KyXXnpJq1evVkpKiu677z4NGjQo0sMCzotIAe0UFxenf/qnf9LMmTNVXFysRYsW6dChQ5Ee1nk1Nzdr3bp1+td//VdFR0frkUce0YQJEziCQo9ApIB2cjgcSkhI0EMPPaSJEydq69ateuCBB1RaWhrpoZ1TS0uLNm3apN/85jcKBAK64447lJOTQ6DQYxApoAMcDodGjhypRx55RCNHjtTatWt1//3368iRI9ad+mtpadGWLVt077336vDhw/rpT3+qhQsXKiEhgUihxyBSQAdFRUVp2rRpevLJJ3XZZZdp5cqVmjt3rkpKShQMBiM9PElSXV2d/vznPys3N1eHDx/WP/zDP+j3v/+9Bg4cSKDQoxApoBMcDoemTp2qxx57TGPGjNHbb7+tefPmKT8/X83NzREblzFGlZWVeuKJJ3Tvvfequrpas2fP1iOPPCK3202g0OMQKaCTnE6nsrKy9Kc//Unf//73tX37dv3yl7/UE088of/93//t1qMqY4yampp04MABzZkzR48++qhqa2t1991369/+7d80ePBgAoUeyWFsO5HeDtXV1UpMTJTf75fb7Y70cNDLGWPk8/n04IMPas2aNfL7/ZoxY4bmzp2rqVOnql+/fl0aCGOMvvjiC61Zs0YrVqzQ0aNHNXz4cN1zzz267bbbFBUVRaBgnfa+jxMpIAzOnmZbs2aNHn/8cX3yyScaNGiQZsyYoV/+8pcaP368YmJiJCkswTj7a3vy5Em99tprWrlypXbs2KGmpibdcMMNWrBggSZMmKDY2NgLfi2gKxApIAKam5tVWlqqxx57TK+99poqKyuVmJioyZMnKzc3V5MmTVJSUpL69u0rp7PjZ9ubm5tVU1Mjn8+nV199VatWrdKhQ4dkjNF3v/td5eXlKScnR4mJiRw9wWpECoigpqYm7dy5Uy+99JLy8/NDl6gPHz5cXq9XkyZN0ne+8x2lp6crNTVV8fHxoVsUORwOGWMUCATU1NSkqqoq+Xw+HT16VPv379f777+v7du369SpU4qPj9fo0aN1/fXX65ZbblFGRgZxQo9ApAAL1NXVqbi4WO+8847Wrl2rDz/8UMFgUC6XS6mpqRo4cKAGDBiglJSU0BGWw+EIHTFVVlaqsrJSp06dUnl5uaqqqhQIBOR2uzVt2jTdcMMNmjx5sr7zne/wN6HQoxApwBLGGLW0tKi2tlb79u3TW2+9pXfffVdlZWWqq6tTfX29mpubFQgEZIyRMUYOh0NOp1NRUVFyuVzq27evkpKSNHLkSGVnZ+vqq6/WoEGDFBcXp6ioqEhvItBhRAqw1NnLxY8fP67Dhw+rrKxMp06dUnV1tRoaGhQMBhUVFaX4+HglJSXJ4/FoyJAhuvTSS5WSktKpz7IA27T3fZzzA0A3czgccrlcGjp0KH8VFziPDv0v2YoVKzR27Fi53W653W55vV5t3LgxNL+hoUF5eXlKSUlRQkKCcnJyVF5e3modpaWlmjVrlvr27avU1FQtWrRILS0t4dkaAMBFpUORGjx4sB599FEVFRVp165d+sEPfqDrr79eJSUlkqR77rlH69at0+rVq1VQUKDjx4/rxhtvDD0/EAho1qxZampq0rZt2/TCCy/o+eef15IlS8K7VQCAi4O5QP379zfPPvusqaqqMjExMWb16tWhefv37zeSTGFhoTHGmA0bNhin02l8Pl9omRUrVhi3220aGxvb/Zp+v99IMn6//0KHDwCIgPa+j3f6E9hAIKBVq1aprq5OXq9XRUVFam5uVlZWVmiZESNGKDMzU4WFhZKkwsJCjRkzRmlpaaFlsrOzVV1dHToaa0tjY6Oqq6tbPQAAF78OR2rv3r1KSEiQy+XS3LlztXbtWo0aNUo+n0+xsbFKSkpqtXxaWpp8Pp8kyefz
tQrU2fln553L0qVLlZiYGHpkZGR0dNgAgB6ow5EaPny4iouLtX37dt11113Kzc3Vvn37umJsIYsXL5bf7w89ysrKuvT1AAB26PAl6LGxsRo2bJgkafz48dq5c6f+8Ic/6KabbgrdwuWrR1Pl5eXyeDySJI/Hox07drRa39mr/84u0xaXyyWXy9XRoQIAergL/lZgMBhUY2Nj6C7PmzZtCs07ePCgSktL5fV6JUler1d79+5VRUVFaJn8/Hy53W6NGjXqQocCALjIdOhIavHixbrmmmuUmZmpmpoarVy5Uu+8847efPNNJSYmas6cOVq4cKGSk5Pldrs1f/58eb1eTZkyRZI0Y8YMjRo1SrfeequWLVsmn8+n+++/X3l5eRwpAQC+oUORqqio0G233aYTJ04oMTFRY8eO1Ztvvqkf/vCHkqTHH39cTqdTOTk5amxsVHZ2tp5++unQ86OiorR+/Xrddddd8nq9io+PV25urh5++OHwbhUA4KLAvfsAAN2uve/j3KkSAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLUuKFKPPvqoHA6HFixYEJrW0NCgvLw8paSkKCEhQTk5OSovL2/1vNLSUs2aNUt9+/ZVamqqFi1apJaWlgsZCgDgItTpSO3cuVN/+tOfNHbs2FbT77nnHq1bt06rV69WQUGBjh8/rhtvvDE0PxAIaNasWWpqatK2bdv0wgsv6Pnnn9eSJUs6vxUAgIuT6YSamhpz+eWXm/z8fDN16lRz9913G2OMqaqqMjExMWb16tWhZffv328kmcLCQmOMMRs2bDBOp9P4fL7QMitWrDBut9s0Nja26/X9fr+RZPx+f2eGDwCIsPa+j3fqSCovL0+zZs1SVlZWq+lFRUVqbm5uNX3EiBHKzMxUYWGhJKmwsFBjxoxRWlpaaJns7GxVV1erpKSkzddrbGxUdXV1qwcA4OIX3dEnrFq1Srt379bOnTu/Mc/n8yk2NlZJSUmtpqelpcnn84WW+Wqgzs4/O68tS5cu1UMPPdTRoQIAergOHUmVlZXp7rvv1ksvvaS4uLiuGtM3LF68WH6/P/QoKyvrttcGAEROhyJVVFSkiooKXXnllYqOjlZ0dLQKCgr05JNPKjo6WmlpaWpqalJVVVWr55WXl8vj8UiSPB7PN672O/vz2WW+zuVyye12t3oAAC5+HYrU9OnTtXfvXhUXF4ceEyZM0OzZs0P/HBMTo02bNoWec/DgQZWWlsrr9UqSvF6v9u7dq4qKitAy+fn5crvdGjVqVJg2CwBwMejQZ1L9+vXT6NGjW02Lj49XSkpKaPqcOXO0cOFCJScny+12a/78+fJ6vZoyZYokacaMGRo1apRuvfVWLVu2TD6fT/fff7/y8vLkcrnCtFkAgItBhy+cOJ/HH39cTqdTOTk5amxsVHZ2tp5++unQ/KioKK1fv1533XWXvF6v4uPjlZubq4cffjjcQwEA9HAOY4yJ9CA6qrq6WomJifL7/Xw+BQA9UHvfx7l3HwDAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQK
AGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLWIFADAWtGRHkBnGGMkSdXV1REeCQCgM86+f599Pz+XHhmp06dPS5IyMjIiPBIAwIWoqalRYmLiOef3yEglJydLkkpLS79143q76upqZWRkqKysTG63O9LDsRb7qX3YT+3DfmofY4xqamqUnp7+rcv1yEg5nV9+lJaYmMh/BO3gdrvZT+3Afmof9lP7sJ/Orz0HGVw4AQCwFpECAFirR0bK5XLpwQcflMvlivRQrMZ+ah/2U/uwn9qH/RReDnO+6/8AAIiQHnkkBQDoHYgUAMBaRAoAYC0iBQCwVo+M1FNPPaUhQ4YoLi5OkydP1o4dOyI9pG61detWXXvttUpPT5fD4dCrr77aar4xRkuWLNGgQYPUp08fZWVl6dNPP221TGVlpWbPni23262kpCTNmTNHtbW13bgVXWvp0qWaOHGi+vXrp9TUVN1www06ePBgq2UaGhqUl5enlJQUJSQkKCcnR+Xl5a2WKS0t1axZs9S3b1+lpqZq0aJFamlp6c5N6VIrVqzQ2LFjQ1889Xq92rhxY2g++6htjz76qBwOhxYsWBCaxr7qIqaHWbVqlYmNjTX//d//bUpKSswdd9xhkpKSTHl5eaSH1m02bNhgfvOb35g1a9YYSWbt2rWt5j/66KMmMTHRvPrqq+bDDz801113nRk6dKipr68PLTNz5kwzbtw488EHH5h3333XDBs2zNxyyy3dvCVdJzs72zz33HPm448/NsXFxebv//7vTWZmpqmtrQ0tM3fuXJORkWE2bdpkdu3aZaZMmWL+5m/+JjS/paXFjB492mRlZZk9e/aYDRs2mAEDBpjFixdHYpO6xGuvvWZef/1188knn5iDBw+af/mXfzExMTHm448/Nsawj9qyY8cOM2TIEDN27Fhz9913h6azr7pGj4vUpEmTTF5eXujnQCBg0tPTzdKlSyM4qsj5eqSCwaDxeDxm+fLloWlVVVXG5XKZl19+2RhjzL59+4wks3PnztAyGzduNA6Hwxw7dqzbxt6dKioqjCRTUFBgjPlyn8TExJjVq1eHltm/f7+RZAoLC40xX/7PgNPpND6fL7TMihUrjNvtNo2Njd27Ad2of//+5tlnn2UftaGmpsZcfvnlJj8/30ydOjUUKfZV1+lRp/uamppUVFSkrKys0DSn06msrCwVFhZGcGT2OHTokHw+X6t9lJiYqMmTJ4f2UWFhoZKSkjRhwoTQMllZWXI6ndq+fXu3j7k7+P1+Sf//5sRFRUVqbm5utZ9GjBihzMzMVvtpzJgxSktLCy2TnZ2t6upqlZSUdOPou0cgENCqVatUV1cnr9fLPmpDXl6eZs2a1WqfSPz31JV61A1mT506pUAg0OpfsiSlpaXpwIEDERqVXXw+nyS1uY/OzvP5fEpNTW01Pzo6WsnJyaFlLibBYFALFizQVVddpdGjR0v6ch/ExsYqKSmp1bJf309t7cez8y4We/fuldfrVUNDgxISErR27VqNGjVKxcXF7KOvWLVqlXbv3q2dO3d+Yx7/PXWdHhUpoDPy8vL08ccf67333ov0UKw0fPhwFRcXy+/363/+53+Um5urgoKCSA/LKmVlZbr77ruVn5+vuLi4SA+nV+lRp/sGDBigqKiob1wxU15eLo/HE6FR2eXsfvi2feTxeFRRUdFqfktLiyorKy+6/Thv3jytX79eW7Zs0eDBg0PTPR6P
mpqaVFVV1Wr5r++ntvbj2XkXi9jYWA0bNkzjx4/X0qVLNW7cOP3hD39gH31FUVGRKioqdOWVVyo6OlrR0dEqKCjQk08+qejoaKWlpbGvukiPilRsbKzGjx+vTZs2haYFg0Ft2rRJXq83giOzx9ChQ+XxeFrto+rqam3fvj20j7xer6qqqlRUVBRaZvPmzQoGg5o8eXK3j7krGGM0b948rV27Vps3b9bQoUNbzR8/frxiYmJa7aeDBw+qtLS01X7au3dvq6Dn5+fL7XZr1KhR3bMhERAMBtXY2Mg++orp06dr7969Ki4uDj0mTJig2bNnh/6ZfdVFIn3lRketWrXKuFwu8/zzz5t9+/aZO++80yQlJbW6YuZiV1NTY/bs2WP27NljJJnf//73Zs+ePebIkSPGmC8vQU9KSjJ/+ctfzEcffWSuv/76Ni9Bv+KKK8z27dvNe++9Zy6//PKL6hL0u+66yyQmJpp33nnHnDhxIvQ4c+ZMaJm5c+eazMxMs3nzZrNr1y7j9XqN1+sNzT97yfCMGTNMcXGxeeONN8zAgQMvqkuG77vvPlNQUGAOHTpkPvroI3PfffcZh8Nh3nrrLWMM++jbfPXqPmPYV12lx0XKGGP+4z/+w2RmZprY2FgzadIk88EHH0R6SN1qy5YtRtI3Hrm5ucaYLy9Df+CBB0xaWppxuVxm+vTp5uDBg63Wcfr0aXPLLbeYhIQE43a7ze23325qamoisDVdo639I8k899xzoWXq6+vNr371K9O/f3/Tt29f8+Mf/9icOHGi1XoOHz5srrnmGtOnTx8zYMAAc++995rm5uZu3pqu84tf/MJceumlJjY21gwcONBMnz49FChj2Eff5uuRYl91Df5UBwDAWj3qMykAQO9CpAAA1iJSAABrESkAgLWIFADAWkQKAGAtIgUAsBaRAgBYi0gBAKxFpAAA1iJSAABrESkAgLX+H5Ah0Ij7+6xyAAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "-141.50604565392825"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "test(play=True)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Gym",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
